Created
March 3, 2026 12:24
-
-
Save brusangues/a01126e8fc4d7e9dcbca8688cfaebdba to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "bee24b83", | |
| "metadata": {}, | |
| "source": [ | |
| "# Construindo uma rede neural apenas com Numpy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3879c3b2", | |
| "metadata": {}, | |
| "source": [ | |
| "https://www.youtube.com/watch?v=w8yWXqWQYmU \n", | |
| "\n", | |
| "https://www.kaggle.com/code/wwsalmon/simple-mnist-nn-from-scratch-numpy-no-tf-keras " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "42a04b14", | |
| "metadata": {}, | |
| "source": [ | |
| "# 1. Dataset MNIST Dígitos" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "c304ddec", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1.1. Leitura dos dados" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "25e3e55e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# import pandas as pd\n", | |
| "# from sklearn.datasets import fetch_openml\n", | |
| "\n", | |
| "# mnist = fetch_openml('mnist_784', version=1)\n", | |
| "# X, y = mnist['data'], mnist['target']\n", | |
| "# X[\"target\"] = y\n", | |
| "# df = pd.DataFrame(X)\n", | |
| "# print(X.shape, df.shape)\n", | |
| "# df" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "d98f0cae", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# len(df)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "85dd10a2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# # Salvando dados\n", | |
| "# step = 5_000\n", | |
| "# len_df = len(df)\n", | |
| "# for i in range(0, len_df, step):\n", | |
| "# print(i, i+step)\n", | |
| "# dfi = df.iloc[i:i+step]\n", | |
| "# dfi.to_csv(f\"mnist/{i}.csv\", index=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "65f3601e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 5000\n", | |
| "5000 10000\n", | |
| "10000 15000\n", | |
| "15000 20000\n", | |
| "20000 25000\n", | |
| "25000 30000\n", | |
| "30000 35000\n", | |
| "35000 40000\n", | |
| "40000 45000\n", | |
| "45000 50000\n", | |
| "50000 55000\n", | |
| "55000 60000\n", | |
| "60000 65000\n", | |
| "65000 70000\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Read the MNIST dump back from its CSV chunks (14 files of 5,000 rows = 70,000 rows).\n", | |
| "# The chunk size is deliberately not called `step`: section 2 defines a `step()`\n", | |
| "# activation function, and re-running this cell would otherwise clobber it.\n", | |
| "import pandas as pd\n", | |
| "\n", | |
| "chunk_size = 5_000\n", | |
| "len_df = 70000\n", | |
| "dfs = []\n", | |
| "for i in range(0, len_df, chunk_size):\n", | |
| "    print(i, i + chunk_size)\n", | |
| "    dfi = pd.read_csv(f\"mnist/{i}.csv\")\n", | |
| "    dfs.append(dfi)\n", | |
| "df_ = pd.concat(dfs, ignore_index=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "7ac4423a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "<class 'pandas.core.frame.DataFrame'>\n", | |
| "RangeIndex: 70000 entries, 0 to 69999\n", | |
| "Columns: 785 entries, pixel1 to target\n", | |
| "dtypes: int64(785)\n", | |
| "memory usage: 419.2 MB\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "df_.info()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "c47cd01c", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1.2. Visualização" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "c7456014", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Features: every pixel column; label: the digit stored in the \"target\" column\n", | |
| "X, y = df_.filter(like=\"pixel\"), df_[\"target\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "8c8b02ce", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "<class 'pandas.core.frame.DataFrame'>\n", | |
| "RangeIndex: 70000 entries, 0 to 69999\n", | |
| "Columns: 784 entries, pixel1 to pixel784\n", | |
| "dtypes: int16(784)\n", | |
| "memory usage: 104.7 MB\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "X = X.astype(\"int16\")\n", | |
| "X.info()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "ad1e952b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "count 70000.000000\n", | |
| "mean 13.230286\n", | |
| "std 50.548067\n", | |
| "min 0.000000\n", | |
| "25% 0.000000\n", | |
| "50% 0.000000\n", | |
| "75% 0.000000\n", | |
| "max 255.000000\n", | |
| "Name: pixel100, dtype: float64" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X.pixel100.describe()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "a0016645", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "count 70000.000000\n", | |
| "mean 13.230286\n", | |
| "std 50.548067\n", | |
| "min 0.000000\n", | |
| "25% 0.000000\n", | |
| "50% 0.000000\n", | |
| "75% 0.000000\n", | |
| "max 255.000000\n", | |
| "Name: pixel100, dtype: float64" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X.pixel100.describe()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "5ea126bc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Text(0.5, 1.0, 'Id: 56982, Label: 5')" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAI2JJREFUeJzt3Qt0FOX5x/En3MI1wRDIhZtcxXItiIgiIFACKgqiFbUtWA8UBBVQ0PiXi1YNYkUqB8GeWoIHREEFKrVBCBK0ggoa8QaHYJBgICI2CYRwEeZ/nvec3WZDAs6yybvZ/X7OGdednXd3dnaYX+Z933knwnEcRwAAqGTVKvsDAQAggAAA1nAGBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAgiVasyYMXLppZey1S2bPXu2REREyI8//hiw9+S3hVsEEC5aamqqOZht3769wremhpd+Vulp/PjxZS6/ceNGGTBggERHR0uDBg2kR48e8vrrr/ssc+zYMZk8ebI0a9ZMIiMj5fLLL5dFixaV+X47duyQG2+8UeLj46V+/frSpUsXeeGFF+TMmTPeZY4cOSLPPvus9O3bVxo3biwNGzaUq6666pzPdat///7SqVMnCVVuf1tUfTVsrwDgVrdu3eTBBx/0mde+fftzlluyZIncc8898pvf/EaefvppqV69uuzevVtycnK8y2hwJCUlmfCcOHGitGvXTtavXy/33nuv/Pe//5VHH33UJ3yuvvpqs8zDDz8sdevWlX//+9/ywAMPyN69e+Wvf/2rWW7r1q3yf//3f3L99dfLY489JjVq1JA333xTRo0aJV9//bU8/vjj/OgX+dsiROhgpMDFWLJkiQ5o63zyyScXXHb06NFOy5Yt/f4sLXvDDTdccLns7GynTp06zv3333/e5VauXGnW/eWXX/aZP3LkSKd27dpOXl6ed97YsWOdWrVqOUeOHPFZtm/fvk5UVJT3+bfffuvs27fPZ5mzZ886AwYMcCIjI51jx445/ujXr5/TsWNHJxBmzZplvvfhw4edQKms3xahgyo4VJg1a9aYKqPatWubx9WrV5e53MGDB2XXrl1y+vTpX/zep06dkqKionJfX7x4sTm7eeKJJ7zVbGUN/P7++++bRz07KUmfnzhxQtauXeudV1hYaL6LVqmVlJCQIHXq1PE+b9WqlbRs2dJnGa1KGj58uJw8eVK+/fZbqSg7d+40bTGtW7c266pVhX/84x9NtWBZtA3ot7/9rURFRUmjRo3M2Zx+79KWLVtmqi/1e8bExJjtU/JMsjwV8dsidBBAqBDvvvuujBw50hx4U1JSzMH37rvvLrOdKDk52bS7fP/997/ovTdt2mSqv7QNRtsNPFVfpdt+OnToIO+8845p29H2Hz3AzpgxQ86ePetdTgNBq+Zq1arlU17f31PtVrINRkPoT3/6k3zzzTfy3XffmaB76623zHe4kEOHDpnH2NhYqSgbNmwwAafbesGCBSYoXnvtNVMdWFYAa/ho4OhvpMtoe9a4ceN8lnnqqafkD3/4g6l6nDdvnmkvS09PN21c+fn5512fivhtEUJsn4IhNKvgunXr5iQkJDj5+fneee+++65ZrnQ1jVbd6HytNruQYcOGOc8884yzZs0aU2127bXXmrLTp0/3WU6rxC655BJT5TVjxgznjTfecO68806z7COPPOJd7rnnnjPz3n//fZ/yuozOv/HGG73zfv75Z2fSpElOzZo1zWs6Va9e3Vm0aNEF11ur7Zo0aWLW11+/pAru+PHj58xbsWKFWdctW7acUwV30003+Sx77733mvmff/65ea5Vifodn3rqKZ/lvvjiC6dGjRo+88uqgquI3xahgwBCwAMoNzf3nAO9x69+9auLaicoTdtWkpKSzMEwJyfHO79atWpmHebMmeOz/JAhQ0zbUGFhoXl+8OBBJzo62mnXrp0JSD1QvvTSSybAtPzAgQN9
yj///PMmlJYuXeq8/vrrzvDhw81nr169utx1PHPmjPlcbT/KzMz0+7u6bQMqLi42bTz6nfS7zJ8//5wAWr9+vU+Zb775xsxPSUkxz+fNm+dEREQ4e/bsMe9Vcrr88sudQYMGBawN6Jf+tggdVMEh4LRqSmmVTWmXXXZZQD9Lq/imTJkiP//8s2zevNk739Mmc8cdd/gsr8+Li4vls88+M8+1jeSf//ynqYobPHiwab+ZNm2aqb5SWhXkMWfOHHnmmWdkxYoVpkpKq6+0XatPnz6mB52uQ1nuu+8+SUtLk7///e/StWtXqUg//fSTaceJi4sz20C7get3UgUFBecsX/o3atOmjVSrVk327dtnnu/Zs8dU3ely+l4lJ62G/OGHHyrsu5T32yJ00A0bVV7z5s29B1+PxMREc/DUA3FJTZo0MY/axdpD2zK03eSLL74wjd8aErm5ued0AX7xxRfNNUUlQ0nddNNNMnXqVHPQbtu2rc9r2uVay2l4/f73v5eKpqH44YcfmhDVLs26rtrmNWTIEJ+2r/Md9EvSMjpPu5trW1lppbdFZfy2CB0EEALO0wNMA6A0vQ4n0Dy9yvSvcg/tsaWfr43f2iPMwxMsJZdVenDVA3bJTgxq0KBB3nl5eXk+F5x6eHp4lT4DWrhwoRlxQBvt9bqhiqahqp0DNPRmzpzpnV/W71DyNc8ZksrKyjKh4xmtQs+I9AxIl7FxPU5Zvy1CB1VwCDjtlqwH86VLl/pU+2gPLb0Q09+uuvpXcOkA0DJ6dqG92K677jrv/Ntvv908vvzyy955emDVi1O1G7EGVHkOHz5sqtp0lIOSAaQHYP0OJbs06/qsXLnS9LLTg7WHjnpw//33y1133WV6jlUGzxlK6d5u8+fPL7eMhmRJnqrHoUOHmsdbbrnFvK+GWun31eflde+uyN8WoYMzIFQI7dZ7ww03mPYRvQ5FDzB6cOvYsaO5Jqd0V10Nq+zs7POOE6dtNU8++aTceuut5i9yfc9XX31VvvzySzPSgbbneNx8880ycOBAsx56rYtWq+l1SR988IG89NJLZsgdj379+knv3r1N9Zl2lf7b3/5m1nHdunWmPcTjkUcekd/97nfSq1cv01VZ21i0PUi7aut61axZ0yz38ccfmzYi7fat67B8+XKf76GjKZQ8K9MqLl2HX9LOoeGon1Wabg8NO61OnDt3rjl4N23a1HSH1+1aHn1NqxC1ik5HcNDrfe68805vW5WGqn6e/kZaxajd6TVstZy2f+l2eOihh8p9/4r4bRFCbPeCQOiOhPDmm2+anlLaFVp7v7311lsX1VV3+/btpqtu06ZNTY+y+vXrO3369DGjGZTl6NGjzgMPPODEx8eb5Tt37uwsW7bsnOWmTJnitG7d2qxn48aNTXftvXv3lvmeaWlppjdabGys9z0XL15c5vYob9LXS66jzhs1apRzIfq55b2np7fegQMHnBEjRjgNGzY0vftuu+02b69E7flWuhfc119/7dx6661OgwYNTLd17WauvedK099St3W9evXM1KFDB2fixInO7t27vctU5m+L0BCh/7EdgkC40gtldXDTzz//XDp37mx7dYBKRRsQYNF7771nRisgfBCOOAMCAFjBGRAAwAoCCABgBQEEALCCAAIAWBF0F6Lq1eo6XIpe7FZ6XCoAQPDTq3uOHj1qxmQseTF30AeQho9nAEIAQNWld83VG0JWmSo4PfMBAFR9FzqeV1gA6SCHOvaT3pdex87S8bF+CardACA0XOh4XiEBpCMB6/1RZs2aJZ9++qkZ2DApKalCb14FAKhiKmKAuSuvvNIMVFjylsSJiYne2/yeT0FBwXkHcmRiG7APsA+wD0iV2AZ6PD+fgJ8BnTp1ygxPX/I+KtoLQp/rcO+l6a2QCwsLfSYAQOgLeADpvVf0xlKlb4Wsz/VeK6Xp/Vqio6O9Ez3gACA8WO8Fpzes0rtmeibttgcACH0Bvw4oNjbW3MI3Ly/P
Z74+L+uuhnpnypJ3pwQAhIeAnwHp/dt79Ogh6enpPqMb6HO97TEAABU2EoJ2wR49erRcccUVcuWVV8r8+fOlqKhI7r77brY6AKDiAuj222+Xw4cPy8yZM03Hg27duklaWto5HRMAAOEr6O6Iqt2wtTccAKBq045lUVFRwdsLDgAQngggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAEEAAgPDBGRAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsqGHnY4Ffrl69eq4315w5c/zaxN26dXNdJjc313WZ2267zXWZiIgI12WeffZZ8UezZs1cl/n1r3/tuswbb7zhuszy5ctdl9m1a5frMqh4nAEBAKwggAAAoRFAs2fPNlUFJacOHToE+mMAAFVchbQBdezYUTZu3Pi/D6lBUxMAwFeFJIMGTnx8fEW8NQAgRFRIG9CePXskMTFRWrduLXfddZfs37+/3GVPnjwphYWFPhMAIPQFPIB69eolqampkpaWJosWLZLs7Gy59tpr5ejRo2Uun5KSItHR0d6pefPmgV4lAEA4BNDQoUPNNQ5dunSRpKQkeeeddyQ/P19WrlxZ5vLJyclSUFDgnXJycgK9SgCAIFThvQMaNmwo7du3l6ysrDJfj4yMNBMAILxU+HVAx44dk71790pCQkJFfxQAIJwD6KGHHpKMjAzZt2+ffPjhhzJixAipXr263HHHHYH+KABAFRbwKrgDBw6YsDly5Ig0btxY+vTpI9u2bTP/DwCAR4TjOI4EEe2Grb3hAI8pU6a43hh/+ctfQm4D+jMYaZD98w6Ikhe5/1LaIQqVTzuWRUVFlfs6Y8EBAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAIIAAAOGDMyAAgBUEEADACgIIAGAFAQQACM0b0gEXS++u69ayZcv8+qxhw4a5LlNcXOy6zD/+8Q/XZXJzc6WydO/e3XWZtm3b+jVYpVuhONBsuOIMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFZEOI7jSBApLCyU6Oho26uBKq5WrVp+ldu6davrMk888YTrMmvXrnVdBqhqdLTzqKiocl/nDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArKhh52OBitW2bVu/ynXr1s11meLiYr8+Cwh3nAEBAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUMRoqQNH/+fNurAOACOAMCAFhBAAEAqkYAbdmyRYYNGyaJiYkSEREha9as8XndcRyZOXOmJCQkSJ06dWTQoEGyZ8+eQK4zACAcA6ioqEi6du0qCxcuLPP1uXPnygsvvCCLFy+Wjz76SOrVqydJSUly4sSJQKwvACBcOyEMHTrUTGXRsx9t/H3sscfk5ptvNvNeeeUViYuLM2dKo0aNuvg1BgCEhIC2AWVnZ8uhQ4dMtZtHdHS09OrVS7Zu3VpmmZMnT0phYaHPBAAIfQENIA0fpWc8Jelzz2ulpaSkmJDyTM2bNw/kKgEAgpT1XnDJyclSUFDgnXJycmyvEgCgqgVQfHy8eczLy/OZr889r5UWGRkpUVFRPhMAIPQFNIBatWplgiY9Pd07T9t0tDdc7969A/lRAIBw
6wV37NgxycrK8ul4kJmZKTExMdKiRQuZPHmyPPnkk9KuXTsTSDNmzDDXDA0fPjzQ6w4ACKcA2r59u1x33XXe51OnTjWPo0ePltTUVJk+fbq5VmjcuHGSn58vffr0kbS0NKldu3Zg1xwAUKVFOHrxThDRKjvtDQdcDK329Uf37t1dlyl52cEvlZGR4boMUNVox7Lztetb7wUHAAhPBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAVI3bMQChrLi42HWZ77//3nUZvX+WWwkJCa7LfPXVV67LAJWFMyAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsILBSBH0EhMTXZdp3769X59Vr14912UyMjJcl4mPj3dd5vDhw67LZGZmij+mT5/uuszOnTv9+iyEL86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKBiNF0Ktbt67rMlFRUVJZ/BlY9Omnn3ZdpkYN9/9cH3zwQfHH0qVLXZe5++67K22wVIQGzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoGI0XQGzVqlOsyERERfn3WqVOnXJeZNm2a6zILFiyQyrBmzRq/ym3cuNF1mQ0bNrgu06lTJ9dl8vLyXJdBcOIMCABgBQEEAKgaAbRlyxYZNmyYJCYmmmqO0qf4Y8aMMfNLTkOGDAnkOgMAwjGAioqKpGvXrrJw4cJyl9HAOXjwoHdasWLFxa4nACDcOyEMHTrUTOcTGRnp110iAQDho0LagDZv3ixNmjSRyy67TCZMmCBHjhwpd9mTJ09KYWGhzwQACH0BDyCtfnvllVckPT1dnnnmGcnIyDBnTGfOnClz+ZSUFImOjvZOzZs3D/QqAQDC4TqgktdsdO7cWbp06SJt2rQxZ0UDBw48Z/nk5GSZOnWq97meARFCABD6KrwbduvWrSU2NlaysrLKbS+KiorymQAAoa/CA+jAgQOmDSghIaGiPwoAEMpVcMeOHfM5m8nOzpbMzEyJiYkx0+OPPy4jR440veD27t0r06dPl7Zt20pSUlKg1x0AEE4BtH37drnuuuu8zz3tN6NHj5ZFixbJzp07ZenSpZKfn28uVh08eLD8+c9/NlVtAAD4HUD9+/cXx3HKfX39+vVu3xI4r/Ptb4Eso/bs2RO0A4v646OPPvKr3HPPPee6zIwZM1yXGTBggOsyXNgeOhgLDgBgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFZEOP4OG1xB9Jbc0dHRtlcDQaRfv36uy8yePduvz5o2bZr4c4uSUFOzZk3XZfbt2+e6TG5urusyPXv2dF0GdhQUFJz3LtecAQEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQxGCiAgFixY4LpMUlKS6zJdu3Z1Xaa4uNh1GVw8BiMFAAQlquAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVNex8LMJVXFyc6zKNGzd2XWbv3r3iDwat9F/79u1dl2nTpo3rMq1bt3Zd5quvvnJdBhWPMyAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsILBSOG3xMRE12W2bNniuszq1atdl5k5c6brMvifoUOHut4cvXr1cl3m1KlTrsucPHnSdRkEJ86AAABWEEAAgOAPoJSUFOnZs6c0aNBAmjRpIsOHD5fdu3f7
LHPixAmZOHGiNGrUSOrXry8jR46UvLy8QK83ACCcAigjI8OEy7Zt22TDhg1y+vRpGTx4sBQVFXmXmTJlirz99tuyatUqs3xubq7ccsstFbHuAIBw6YSQlpbm8zw1NdWcCe3YsUP69u0rBQUF8vLLL8urr74qAwYMMMssWbJELr/8chNaV111VWDXHgAQnm1AGjgqJibGPGoQ6VnRoEGDvMt06NBBWrRoIVu3bi23R0thYaHPBAAIfX4H0NmzZ2Xy5MlyzTXXSKdOncy8Q4cOSa1ataRhw4Y+y8bFxZnXymtXio6O9k7Nmzf3d5UAAOEQQNoW9OWXX8prr712USuQnJxszqQ8U05OzkW9HwAghC9EnTRpkqxbt85cVNisWTPv/Pj4eHNhWX5+vs9ZkPaC09fKEhkZaSYAQHhxdQbkOI4JH70yfdOmTdKqVSuf13v06CE1a9aU9PR07zztpr1//37p3bt34NYaABBeZ0Ba7aY93NauXWuuBfK062jbTZ06dczjPffcI1OnTjUdE6KiouS+++4z4UMPOACA3wG0aNEi89i/f3+f+drVesyYMeb/n3/+ealWrZq5AFV7uCUlJcmLL77o5mMAAGGghtsquAupXbu2LFy40EwIbXoNmFv+9HL86aefXJcpLi52XSYUef4wdEt7p7qltSJuzZs3z3WZrKws12UQnBgLDgBgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFZEOL9kiOtKVFhYaO4rhND01VdfuS7TokUL12X0vlX++OSTT1yXOXLkiOsy5d0h+Hyuvvpq12UGDhwo/oiLi3Nd5sCBA67L6O1a3Nq1a5frMrCjoKDA3BeuPJwBAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVDEaKSnXFFVe4LvOvf/3LdZnY2FgJNREREa7L/Pzzz3591vfff++6zPjx412XWb9+vesyqDoYjBQAEJSoggMAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYwGCmCXvfu3SttkMuYmBjXZc6cOeO6zKpVqyplMNL58+eLPzIzM12XOXXqlF+fhdDFYKQAgKBEFRwAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCwUgBABWCwUgBAEGJKjgAQPAHUEpKivTs2VMaNGggTZo0keHDh8vu3bt9lunfv7+5b0nJafz48YFebwBAOAVQRkaGTJw4UbZt2yYbNmyQ06dPy+DBg6WoqMhnubFjx8rBgwe909y5cwO93gCAKq6Gm4XT0tJ8nqemppozoR07dkjfvn298+vWrSvx8fGBW0sAQMipdrE9HMq6jfHy5cslNjZWOnXqJMnJyXL8+PFy3+PkyZNSWFjoMwEAwoDjpzNnzjg33HCDc8011/jMf+mll5y0tDRn586dzrJly5ymTZs6I0aMKPd9Zs2a5ehqMLEN2AfYB9gHJKS2QUFBwXlzxO8AGj9+vNOyZUsnJyfnvMulp6ebFcnKyirz9RMnTpiV9Ez6frY3GhPbgH2AfYB9QCo8gFy1AXlMmjRJ1q1bJ1u2bJFmzZqdd9levXqZx6ysLGnTps05r0dGRpoJABBeXAWQnjHdd999snr1atm8ebO0atXqgmUyMzPNY0JCgv9rCQAI7wDSLtivvvqqrF271lwLdOjQITM/Ojpa6tSpI3v37jWvX3/99dKoUSPZuXOnTJkyxfSQ69KlS0V9BwBAVeSm3ae8er4lS5aY1/fv3+/07dvXiYmJcSIjI522bds606ZNu2A9YEm6LHWv1L+zD7APsA9U/X3gQsd+BiMFAFQIBiMFAAQlBiMFAFhBAAEArCCAAABWEEAAACsIIACAFQQQ
AMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwIugCyHEc26sAAKiE43nQBdDRo0dtrwIAoBKO5xFOkJ1ynD17VnJzc6VBgwYSERHh81phYaE0b95ccnJyJCoqSsIV24HtwP7Av4tgPj5orGj4JCYmSrVq5Z/n1JAgoyvbrFmz8y6jGzWcA8iD7cB2YH/g30WwHh+io6MvuEzQVcEBAMIDAQQAsKJKBVBkZKTMmjXLPIYztgPbgf2BfxehcHwIuk4IAIDwUKXOgAAAoYMAAgBYQQABAKwggAAAVhBAAAArqkwALVy4UC699FKpXbu29OrVSz7++GPbq1TpZs+ebYYnKjl16NBBQt2WLVtk2LBhZlgP/c5r1qzxeV07cs6cOVMSEhKkTp06MmjQINmzZ4+E23YYM2bMOfvHkCFDJJSkpKRIz549zVBdTZo0keHDh8vu3bt9ljlx4oRMnDhRGjVqJPXr15eRI0dKXl6ehNt26N+//zn7w/jx4yWYVIkAev3112Xq1Kmmb/unn34qXbt2laSkJPnhhx8k3HTs2FEOHjzonT744AMJdUVFReY31z9CyjJ37lx54YUXZPHixfLRRx9JvXr1zP6hB6Jw2g5KA6fk/rFixQoJJRkZGSZctm3bJhs2bJDTp0/L4MGDzbbxmDJlirz99tuyatUqs7yOLXnLLbdIuG0HNXbsWJ/9Qf+tBBWnCrjyyiudiRMnep+fOXPGSUxMdFJSUpxwMmvWLKdr165OONNddvXq1d7nZ8+edeLj451nn33WOy8/P9+JjIx0VqxY4YTLdlCjR492br75Ziec/PDDD2ZbZGRkeH/7mjVrOqtWrfIu880335hltm7d6oTLdlD9+vVzHnjgASeYBf0Z0KlTp2THjh2mWqXkgKX6fOvWrRJutGpJq2Bat24td911l+zfv1/CWXZ2thw6dMhn/9BBELWaNhz3j82bN5sqmcsuu0wmTJggR44ckVBWUFBgHmNiYsyjHiv0bKDk/qDV1C1atAjp/aGg1HbwWL58ucTGxkqnTp0kOTlZjh8/LsEk6EbDLu3HH3+UM2fOSFxcnM98fb5r1y4JJ3pQTU1NNQcXPZ1+/PHH5dprr5Uvv/zS1AWHIw0fVdb+4XktXGj1m1Y1tWrVSvbu3SuPPvqoDB061Bx4q1evLqFGb90yefJkueaaa8wBVulvXqtWLWnYsGHY7A9ny9gO6s4775SWLVuaP1h37twpDz/8sGkneuuttyRYBH0A4X/0YOLRpUsXE0i6g61cuVLuueceNlWYGzVqlPf/O3fubPaRNm3amLOigQMHSqjRNhD94ysc2kH92Q7jxo3z2R+0k47uB/rHie4XwSDoq+D09FH/eivdi0Wfx8fHSzjTv/Lat28vWVlZEq48+wD7x7m0mlb//YTi/jFp0iRZt26dvPfeez73D9P9Qavt8/Pzw+J4Mamc7VAW/YNVBdP+EPQBpKfTPXr0kPT0dJ9TTn3eu3dvCWfHjh0zf83oXzbhSqub9MBScv/QO0Jqb7hw3z8OHDhg2oBCaf/Q/hd60F29erVs2rTJ/P4l6bGiZs2aPvuDVjtpW2ko7Q/OBbZDWTIzM81jUO0PThXw2muvmV5Nqampztdff+2MGzfOadiwoXPo0CEnnDz44IPO5s2bnezsbOc///mPM2jQICc2Ntb0gAllR48edT777DMz6S47b9488//fffedeX3OnDlmf1i7dq2zc+dO0xOsVatWTnFxsRMu20Ffe+ihh0xPL90/Nm7c6HTv3t1p166dc+LECSdUTJgwwYmOjjb/Dg4ePOidjh8/7l1m/PjxTosWLZxNmzY527dvd3r37m2mUDLhAtshKyvLeeKJJ8z31/1B/220bt3a6du3rxNMqkQAqQULFpidqlatWqZb9rZt25xwc/vt
tzsJCQlmGzRt2tQ81x0t1L333nvmgFt60m7Hnq7YM2bMcOLi4swfKgMHDnR2797thNN20APP4MGDncaNG5tuyC1btnTGjh0bcn+klfX9dVqyZIl3Gf3D495773UuueQSp27dus6IESPMwTmctsP+/ftN2MTExJh/E23btnWmTZvmFBQUOMGE+wEBAKwI+jYgAEBoIoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQAIIABA+OAMCAAgNvw/zJG9Fliu2m4AAAAASUVORK5CYII=", | |
| "text/plain": [ | |
| "<Figure size 640x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
| "# Show one random digit with its label.\n", | |
| "# `sample()` returns an index label, so look it up with `.loc` rather than the\n", | |
| "# positional `.iloc` (which only worked because X has a default RangeIndex),\n", | |
| "# and use a name that does not shadow the builtin `id`.\n", | |
| "sample_id = X.sample(1).index[0]\n", | |
| "fig, ax = plt.subplots()\n", | |
| "ax.imshow(X.loc[sample_id].to_numpy().reshape(28, 28), cmap=\"gray\")\n", | |
| "ax.set_title(f\"Id: {sample_id}, Label: {y.loc[sample_id]}\");" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "58ba3512", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "y = y.astype(\"int16\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "7d25ace7", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1.3. Train test split" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "18929619", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(69000, 784)\n", | |
| "(1000, 784)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from sklearn.model_selection import train_test_split\n", | |
| "\n", | |
| "# Hold out a fixed 1,000-image test set; random_state makes the shuffle reproducible\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(\n", | |
| "    X.to_numpy(), y.to_numpy(), test_size=1000, shuffle=True, random_state=42\n", | |
| ")\n", | |
| "print(X_train.shape)\n", | |
| "print(X_test.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "180bf32b", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1.4. Scaling" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "af330d6f", | |
| "metadata": {}, | |
| "source": [ | |
| "A escala dos valores de entrada é muito importante para o treinamento da rede: entradas grandes (como pixels de 0 a 255) geram ativações e gradientes grandes, que atrapalham a convergência. \n", | |
| "Aqui poderíamos aplicar o standard scaler, mas, como já sabemos o intervalo que os pixels podem assumir ([0, 255]), \n", | |
| "basta um min-max simples: dividir cada pixel por 255 para levá-lo ao intervalo [0, 1]." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "040c3575", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "99be654f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(69000, 784)" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_train.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "ea8b393d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(784,)" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Standard-scaler statistics, computed per pixel column on the training split only.\n", | |
| "# A constant pixel would get std == 0; apply_scaling passes zero-divisor columns through unchanged.\n", | |
| "means, stds = X_train.mean(axis=0), X_train.std(axis=0)\n", | |
| "stds.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "7967eb25", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Min-max scaler parameters: pixels lie in [0, 255], so subtract 0 and divide by 255 per column\n", | |
| "subtract, divide = [0.0] * X_train.shape[1], [255.0] * X_train.shape[1]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "e144fa81", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def apply_scaling(X_train, subtract, divide):\n", | |
| "    \"\"\"Scale column j of a 2-D array as (X[:, j] - subtract[j]) / divide[j].\n", | |
| "\n", | |
| "    Vectorized: one broadcast operation instead of a per-column Python loop.\n", | |
| "    Columns whose divisor is 0 (e.g. zero-std pixels under standard scaling)\n", | |
| "    are passed through unchanged instead of producing inf/nan.\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    X_train : array-like, shape (n_samples, n_features)\n", | |
| "        Data to scale (used for the test split as well, despite the name).\n", | |
| "    subtract, divide : sequence of float, length n_features\n", | |
| "        Per-column offset and divisor.\n", | |
| "\n", | |
| "    Returns\n", | |
| "    -------\n", | |
| "    np.ndarray of float64 with the same shape as X_train (its shape is also printed).\n", | |
| "    \"\"\"\n", | |
| "    X = np.asarray(X_train)\n", | |
| "    offset = np.asarray(subtract, dtype=float)\n", | |
| "    scale = np.asarray(divide, dtype=float)\n", | |
| "    zero_scale = scale == 0\n", | |
| "    # np.where evaluates both branches, so divide by 1 where the divisor is 0\n", | |
| "    # to avoid a divide-by-zero warning; those columns are taken from X unchanged.\n", | |
| "    safe_scale = np.where(zero_scale, 1.0, scale)\n", | |
| "    X_scaled = np.where(zero_scale, X, (X - offset) / safe_scale)\n", | |
| "    print(X_scaled.shape)\n", | |
| "    return X_scaled" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "3e2a5ec1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(69000, 784)\n", | |
| "(1000, 784)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Scale both splits with the same parameters (`subtract` / `divide` from the min-max cell above).\n", | |
| "# Test data must reuse the train-side parameters; swap in `means` / `stds` for standard scaling.\n", | |
| "X_train_scaled = apply_scaling(X_train, subtract, divide)\n", | |
| "X_test_scaled = apply_scaling(X_test, subtract, divide)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "2077e5e7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>0</th>\n", | |
| " <th>1</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>count</th>\n", | |
| " <td>69000.0000</td>\n", | |
| " <td>69000.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>mean</th>\n", | |
| " <td>12.9836</td>\n", | |
| " <td>11.5725</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>std</th>\n", | |
| " <td>49.7809</td>\n", | |
| " <td>47.1865</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>min</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>25%</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>50%</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>75%</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>max</th>\n", | |
| " <td>255.0000</td>\n", | |
| " <td>255.0000</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " 0 1\n", | |
| "count 69000.0000 69000.0000\n", | |
| "mean 12.9836 11.5725\n", | |
| "std 49.7809 47.1865\n", | |
| "min 0.0000 0.0000\n", | |
| "25% 0.0000 0.0000\n", | |
| "50% 0.0000 0.0000\n", | |
| "75% 0.0000 0.0000\n", | |
| "max 255.0000 255.0000" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pd.DataFrame(X_train[:, 100:102]).describe().round(4)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "9823c865", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>0</th>\n", | |
| " <th>1</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>count</th>\n", | |
| " <td>69000.0000</td>\n", | |
| " <td>69000.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>mean</th>\n", | |
| " <td>0.0509</td>\n", | |
| " <td>0.0454</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>std</th>\n", | |
| " <td>0.1952</td>\n", | |
| " <td>0.1850</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>min</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>25%</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>50%</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>75%</th>\n", | |
| " <td>0.0000</td>\n", | |
| " <td>0.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>max</th>\n", | |
| " <td>1.0000</td>\n", | |
| " <td>1.0000</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " 0 1\n", | |
| "count 69000.0000 69000.0000\n", | |
| "mean 0.0509 0.0454\n", | |
| "std 0.1952 0.1850\n", | |
| "min 0.0000 0.0000\n", | |
| "25% 0.0000 0.0000\n", | |
| "50% 0.0000 0.0000\n", | |
| "75% 0.0000 0.0000\n", | |
| "max 1.0000 1.0000" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pd.DataFrame(X_train_scaled[:, 100:102]).describe().round(4)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "15aded6d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(784, 69000)\n", | |
| "(784, 1000)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Transpose to (features, samples): each column becomes one example, so a weight matrix of\n", | |
| "# shape (n_out, 784) can multiply the data directly. The shape guards make the cell idempotent\n", | |
| "# (re-running it used to flip the matrices back). The labels are 1-D, where `.T` is a no-op.\n", | |
| "if X_train_scaled.shape == X_train.shape:\n", | |
| "    X_train_scaled = X_train_scaled.T\n", | |
| "if X_test_scaled.shape == X_test.shape:\n", | |
| "    X_test_scaled = X_test_scaled.T\n", | |
| "print(X_train_scaled.shape)\n", | |
| "print(X_test_scaled.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "096d2430", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(784,)" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_train_scaled[:, 0].shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0b41f758", | |
| "metadata": {}, | |
| "source": [ | |
| "# 2. Funções de ativação" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "fb84d8e4", | |
| "metadata": {}, | |
| "source": [ | |
| "A escolha da função de ativação das hidden layers é, em princípio, livre, contanto que a derivada seja calculável. \n", | |
| "Já a camada de saída deve passar por uma função adequada à tarefa. \n", | |
| "No caso da classificação multiclasse, vamos usar a softmax para transformar os valores de saída em uma distribuição \n", | |
| "de probabilidade cuja soma é 1." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "97f7a916", | |
| "metadata": {}, | |
| "source": [ | |
| "" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "c1bcba9b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "id": "8032439b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Activation functions\n", | |
| "\n", | |
| "def sigmoid(z):\n", | |
| "    \"\"\"Logistic sigmoid 1 / (1 + exp(-z)), evaluated without overflow.\n", | |
| "\n", | |
| "    Rewritten as exp(-log(1 + exp(-z))) using np.logaddexp. The naive form overflows\n", | |
| "    np.exp(-z) for very negative z (e.g. -1000) and emits a RuntimeWarning; this one\n", | |
| "    returns the same values (up to rounding) silently.\n", | |
| "    \"\"\"\n", | |
| "    return np.exp(-np.logaddexp(0, -z))\n", | |
| "\n", | |
| "def step(z):\n", | |
| "    \"\"\"Heaviside step: 1 where z >= 0, else 0 (integer array).\"\"\"\n", | |
| "    return np.where(z >= 0, 1, 0)\n", | |
| "\n", | |
| "def relu(z):\n", | |
| "    \"\"\"Rectified linear unit: elementwise max(0, z).\"\"\"\n", | |
| "    return np.maximum(0, z)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "id": "51a78688", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.1, 0.2, 0.3]])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.30060961, 0.33222499, 0.3671654 ]])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "np.float64(1.0000000000000002)" | |
| ] | |
| }, | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Naive softmax: normalizes over ALL elements of z (there is no axis argument) and\n", | |
| "# exponentiates the raw values, so np.exp overflows for large inputs (e.g. z = 1000\n", | |
| "# gives inf / inf = nan). The demo below also shows a small floating-point rounding error.\n", | |
| "def softmax(z):\n", | |
| "    exps = np.exp(z)\n", | |
| "    return exps / exps.sum()\n", | |
| "\n", | |
| "# Demo on a single 1x3 row; the total comes out as 1.0000000000000002, not exactly 1.0\n", | |
| "x = np.array([[0.1, 0.2, 0.3]])\n", | |
| "display(x)\n", | |
| "display(softmax(x))\n", | |
| "softmax(x).sum()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "id": "a63b144d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.1, 0.2, 0.3]])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.30060961, 0.33222499, 0.3671654 ]])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "np.float64(1.0)" | |
| ] | |
| }, | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Numerically stable softmax (redefines the naive version above)\n", | |
| "def softmax(x, axis=0):\n", | |
| "    \"\"\"Softmax along `axis` (default 0, i.e. each column is normalized independently).\n", | |
| "\n", | |
| "    Shifting by the per-slice maximum leaves the result mathematically unchanged\n", | |
| "    but caps every exponent at exp(0) = 1, so np.exp cannot overflow.\n", | |
| "    \"\"\"\n", | |
| "    shifted = x - np.max(x, axis=axis, keepdims=True)\n", | |
| "    exps = np.exp(shifted)\n", | |
| "    # keepdims keeps the reduced axis so the division broadcasts slice by slice\n", | |
| "    return exps / np.sum(exps, axis=axis, keepdims=True)\n", | |
| "\n", | |
| "x = np.array([[0.1, 0.2, 0.3]])\n", | |
| "display(x)\n", | |
| "output = softmax(x, axis=1)\n", | |
| "display(output)\n", | |
| "output.sum()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "5797a04a", | |
| "metadata": {}, | |
| "source": [ | |
| "# 3. Inicialização dos parâmetros treináveis" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9c609550", | |
| "metadata": {}, | |
| "source": [ | |
| "Aqui os parâmetros de pesos e biases são criados. \n", | |
| "A forma com que esses parâmetros são inicializados é muito importante para o treinamento. \n", | |
| "A inicialização com valores aleatórios é bem simples, mas existem formas melhores, como a inicialização de Glorot (Xavier), implementada mais abaixo." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "id": "376cfdc5", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "W1.shape=(12, 784)\n", | |
| "b1.shape=(12, 1)\n", | |
| "W2.shape=(10, 12)\n", | |
| "b2.shape=(10, 1)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# https://numpy.org/doc/stable/reference/random/generated/numpy.random.Generator.standard_normal.html\n", | |
| "\n", | |
| "def init_params_standard_normal(input_len=784, n_hidden=12, n_output=10, sigma=1, mu=0, rng=None):\n", | |
| "    \"\"\"Initialise every weight and bias from a normal distribution N(mu, sigma**2).\n", | |
| "\n", | |
| "    input_len, n_hidden, n_output: layer sizes (784 -> 12 -> 10 by default).\n", | |
| "    sigma, mu: standard deviation and mean of the distribution.\n", | |
| "    rng: optional np.random.Generator for reproducible draws. When None the\n", | |
| "        legacy global np.random state is used, exactly as before.\n", | |
| "\n", | |
| "    Returns (W1, b1, W2, b2) with shapes (n_hidden, input_len), (n_hidden, 1),\n", | |
| "    (n_output, n_hidden) and (n_output, 1).\n", | |
| "    \"\"\"\n", | |
| "    # Both the np.random module and a Generator expose standard_normal(size);\n", | |
| "    # np.random.randn(a, b) is the same call as np.random.standard_normal((a, b)).\n", | |
| "    source = np.random if rng is None else rng\n", | |
| "    W1 = source.standard_normal((n_hidden, input_len)) * sigma + mu\n", | |
| "    b1 = source.standard_normal((n_hidden, 1)) * sigma + mu\n", | |
| "    W2 = source.standard_normal((n_output, n_hidden)) * sigma + mu\n", | |
| "    b2 = source.standard_normal((n_output, 1)) * sigma + mu\n", | |
| "    print(f\"{W1.shape=}\\n{b1.shape=}\\n{W2.shape=}\\n{b2.shape=}\")\n", | |
| "    return W1, b1, W2, b2\n", | |
| "\n", | |
| "W1, b1, W2, b2 = init_params_standard_normal()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "id": "34f15d4c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>0</th>\n", | |
| " <th>1</th>\n", | |
| " <th>2</th>\n", | |
| " <th>3</th>\n", | |
| " <th>4</th>\n", | |
| " <th>5</th>\n", | |
| " <th>6</th>\n", | |
| " <th>7</th>\n", | |
| " <th>8</th>\n", | |
| " <th>9</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>count</th>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " <td>12.0000</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>mean</th>\n", | |
| " <td>0.2374</td>\n", | |
| " <td>-0.0184</td>\n", | |
| " <td>-0.2124</td>\n", | |
| " <td>0.0696</td>\n", | |
| " <td>-0.6559</td>\n", | |
| " <td>0.1702</td>\n", | |
| " <td>0.0769</td>\n", | |
| " <td>-0.1195</td>\n", | |
| " <td>0.2435</td>\n", | |
| " <td>-0.2746</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>std</th>\n", | |
| " <td>1.0505</td>\n", | |
| " <td>0.9126</td>\n", | |
| " <td>0.9761</td>\n", | |
| " <td>1.2073</td>\n", | |
| " <td>0.7606</td>\n", | |
| " <td>1.0797</td>\n", | |
| " <td>0.7354</td>\n", | |
| " <td>0.5563</td>\n", | |
| " <td>0.7472</td>\n", | |
| " <td>0.8609</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>min</th>\n", | |
| " <td>-1.3007</td>\n", | |
| " <td>-1.7274</td>\n", | |
| " <td>-2.2391</td>\n", | |
| " <td>-3.0794</td>\n", | |
| " <td>-2.0968</td>\n", | |
| " <td>-2.0855</td>\n", | |
| " <td>-1.3020</td>\n", | |
| " <td>-1.2476</td>\n", | |
| " <td>-0.8955</td>\n", | |
| " <td>-1.9726</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>25%</th>\n", | |
| " <td>-0.3717</td>\n", | |
| " <td>-0.4912</td>\n", | |
| " <td>-0.6852</td>\n", | |
| " <td>-0.2138</td>\n", | |
| " <td>-0.8213</td>\n", | |
| " <td>-0.3437</td>\n", | |
| " <td>-0.2522</td>\n", | |
| " <td>-0.3849</td>\n", | |
| " <td>-0.3035</td>\n", | |
| " <td>-0.8565</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>50%</th>\n", | |
| " <td>0.0006</td>\n", | |
| " <td>-0.0520</td>\n", | |
| " <td>-0.1960</td>\n", | |
| " <td>0.3400</td>\n", | |
| " <td>-0.5496</td>\n", | |
| " <td>0.1192</td>\n", | |
| " <td>0.2696</td>\n", | |
| " <td>-0.1851</td>\n", | |
| " <td>0.0482</td>\n", | |
| " <td>-0.1642</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>75%</th>\n", | |
| " <td>0.6263</td>\n", | |
| " <td>0.6239</td>\n", | |
| " <td>0.6728</td>\n", | |
| " <td>0.8600</td>\n", | |
| " <td>-0.0104</td>\n", | |
| " <td>0.5433</td>\n", | |
| " <td>0.6566</td>\n", | |
| " <td>0.1980</td>\n", | |
| " <td>0.8000</td>\n", | |
| " <td>0.3913</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>max</th>\n", | |
| " <td>2.4439</td>\n", | |
| " <td>1.1715</td>\n", | |
| " <td>1.0374</td>\n", | |
| " <td>1.2599</td>\n", | |
| " <td>0.2005</td>\n", | |
| " <td>2.2942</td>\n", | |
| " <td>0.9208</td>\n", | |
| " <td>0.9672</td>\n", | |
| " <td>1.7092</td>\n", | |
| " <td>0.8682</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " 0 1 2 3 4 5 6 7 \\\n", | |
| "count 12.0000 12.0000 12.0000 12.0000 12.0000 12.0000 12.0000 12.0000 \n", | |
| "mean 0.2374 -0.0184 -0.2124 0.0696 -0.6559 0.1702 0.0769 -0.1195 \n", | |
| "std 1.0505 0.9126 0.9761 1.2073 0.7606 1.0797 0.7354 0.5563 \n", | |
| "min -1.3007 -1.7274 -2.2391 -3.0794 -2.0968 -2.0855 -1.3020 -1.2476 \n", | |
| "25% -0.3717 -0.4912 -0.6852 -0.2138 -0.8213 -0.3437 -0.2522 -0.3849 \n", | |
| "50% 0.0006 -0.0520 -0.1960 0.3400 -0.5496 0.1192 0.2696 -0.1851 \n", | |
| "75% 0.6263 0.6239 0.6728 0.8600 -0.0104 0.5433 0.6566 0.1980 \n", | |
| "max 2.4439 1.1715 1.0374 1.2599 0.2005 2.2942 0.9208 0.9672 \n", | |
| "\n", | |
| " 8 9 \n", | |
| "count 12.0000 12.0000 \n", | |
| "mean 0.2435 -0.2746 \n", | |
| "std 0.7472 0.8609 \n", | |
| "min -0.8955 -1.9726 \n", | |
| "25% -0.3035 -0.8565 \n", | |
| "50% 0.0482 -0.1642 \n", | |
| "75% 0.8000 0.3913 \n", | |
| "max 1.7092 0.8682 " | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Distribution of the weights linking the first 10 inputs to each hidden unit\n", | |
| "first_weights = pd.DataFrame(W1[:, :10])\n", | |
| "first_weights.describe().round(4)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "id": "e376e808", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "W1.shape=(12, 784)\n", | |
| "b1.shape=(12, 1)\n", | |
| "W2.shape=(10, 12)\n", | |
| "b2.shape=(10, 1)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def init_params_uniform(input_len=784, n_hidden=12, n_output=10, low=-0.5, high=0.5, rng=None):\n", | |
| "    \"\"\"Initialise every weight and bias from a uniform distribution on [low, high).\n", | |
| "\n", | |
| "    rng: optional np.random.Generator for reproducible draws. When None the\n", | |
| "        legacy global np.random state is used, exactly as before.\n", | |
| "\n", | |
| "    Returns (W1, b1, W2, b2) with shapes (n_hidden, input_len), (n_hidden, 1),\n", | |
| "    (n_output, n_hidden) and (n_output, 1).\n", | |
| "    \"\"\"\n", | |
| "    source = np.random if rng is None else rng\n", | |
| "    W1 = source.uniform(low, high, (n_hidden, input_len))\n", | |
| "    b1 = source.uniform(low, high, (n_hidden, 1))\n", | |
| "    W2 = source.uniform(low, high, (n_output, n_hidden))\n", | |
| "    b2 = source.uniform(low, high, (n_output, 1))\n", | |
| "    print(f\"{W1.shape=}\\n{b1.shape=}\\n{W2.shape=}\\n{b2.shape=}\")\n", | |
| "    return W1, b1, W2, b2\n", | |
| "\n", | |
| "W1, b1, W2, b2 = init_params_uniform()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "id": "4420e86b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "W1.shape=(12, 784)\n", | |
| "b1.shape=(12, 1)\n", | |
| "W2.shape=(10, 12)\n", | |
| "b2.shape=(10, 1)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def init_params_glorot(input_len=784, n_hidden=12, n_output=10, rng=None):\n", | |
| "    \"\"\"Glorot/Xavier uniform initialisation: W ~ U(-limit, limit), biases zero.\n", | |
| "\n", | |
| "    For each layer limit = sqrt(6 / (fan_in + fan_out)) (Glorot & Bengio, 2010).\n", | |
| "    rng: optional np.random.Generator for reproducible draws. When None the\n", | |
| "        legacy global np.random state is used, exactly as before.\n", | |
| "\n", | |
| "    Returns (W1, b1, W2, b2) with shapes (n_hidden, input_len), (n_hidden, 1),\n", | |
| "    (n_output, n_hidden) and (n_output, 1).\n", | |
| "    \"\"\"\n", | |
| "    source = np.random if rng is None else rng\n", | |
| "\n", | |
| "    def glorot_uniform(fan_in, fan_out):\n", | |
| "        # Shaped (fan_out, fan_in) to match the W @ X convention used in forward()\n", | |
| "        limit = np.sqrt(6.0 / (fan_in + fan_out))\n", | |
| "        return source.uniform(-limit, limit, (fan_out, fan_in))\n", | |
| "\n", | |
| "    W1 = glorot_uniform(input_len, n_hidden)\n", | |
| "    b1 = np.zeros((n_hidden, 1))\n", | |
| "    W2 = glorot_uniform(n_hidden, n_output)\n", | |
| "    b2 = np.zeros((n_output, 1))\n", | |
| "\n", | |
| "    print(f\"{W1.shape=}\\n{b1.shape=}\\n{W2.shape=}\\n{b2.shape=}\")\n", | |
| "    return W1, b1, W2, b2\n", | |
| "\n", | |
| "W1, b1, W2, b2 = init_params_glorot()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "914e78f1", | |
| "metadata": {}, | |
| "source": [ | |
| "# 4. Forward pass" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "id": "1a23d3eb", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# forward() calls OUTPUT_FUNCTION with an axis= keyword, so it must be the\n", | |
| "# stable softmax defined above (the naive version takes no axis argument).\n", | |
| "ACTIVATION_FUNCTION, OUTPUT_FUNCTION = relu, softmax" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "id": "062666d1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Z1.shape=(12, 2)\n", | |
| "A1.shape=(12, 2)\n", | |
| "Z2.shape=(10, 2)\n", | |
| "A2.shape=(10, 2)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def forward(X, W1, b1, W2, b2, verbose=False):\n", | |
| "    \"\"\"Forward pass of the 2-layer network for a batch X (one sample per column).\n", | |
| "\n", | |
| "    Hidden layer: Z1 = W1 @ X + b1, A1 = ACTIVATION_FUNCTION(Z1).\n", | |
| "    Output layer: Z2 = W2 @ A1 + b2, A2 = OUTPUT_FUNCTION(Z2) along axis 0, so\n", | |
| "    each column of A2 holds one sample's class probabilities.\n", | |
| "    Returns (Z1, A1, Z2, A2); the intermediates are reused by backward().\n", | |
| "    \"\"\"\n", | |
| "    Z1 = W1 @ X + b1  # b1 is (n_hidden, 1) and broadcasts across the columns\n", | |
| "    A1 = ACTIVATION_FUNCTION(Z1)\n", | |
| "    Z2 = W2 @ A1 + b2\n", | |
| "    A2 = OUTPUT_FUNCTION(Z2, axis=0)\n", | |
| "    if verbose:\n", | |
| "        shapes = [f\"{Z1.shape=}\", f\"{A1.shape=}\", f\"{Z2.shape=}\", f\"{A2.shape=}\"]\n", | |
| "        print(\"\\n\".join(shapes))\n", | |
| "    return Z1, A1, Z2, A2\n", | |
| "\n", | |
| "Z1, A1, Z2, A2 = forward(X_train_scaled[:, 0:2], W1, b1, W2, b2, verbose=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "id": "8fb9a23d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[0.07478467 0.12267795]\n", | |
| " [0.11597149 0.10304144]\n", | |
| " [0.1204467 0.09488816]\n", | |
| " [0.08310441 0.06405648]\n", | |
| " [0.14265181 0.13316973]\n", | |
| " [0.07147575 0.10998027]\n", | |
| " [0.10330583 0.096492 ]\n", | |
| " [0.11945251 0.09841706]\n", | |
| " [0.08345286 0.09008889]\n", | |
| " [0.08535398 0.08718803]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# A2 is the network output: one column of class probabilities per sample\n", | |
| "print(A2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "c1ca11cc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([4, 4])" | |
| ] | |
| }, | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# The row with the highest probability in each column is the predicted class\n", | |
| "np.argmax(A2, axis=0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "id": "1128c475", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([1., 1.])" | |
| ] | |
| }, | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Softmax normalises each column, so every sample's probabilities sum to 1\n", | |
| "np.sum(A2, axis=0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3b011ec0", | |
| "metadata": {}, | |
| "source": [ | |
| "# 5. One hot" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "004504d4", | |
| "metadata": {}, | |
| "source": [ | |
| "Transforma o vetor de labels (números entre 0 e 9) em uma matriz one-hot de forma (10, m): cada coluna corresponde a uma amostra e tem 1 na linha do seu rótulo e 0 nas demais." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "00e7abe4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([ 0, 1, 2, ..., 68997, 68998, 68999], shape=(69000,))" | |
| ] | |
| }, | |
| "execution_count": 36, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Row indices 0..m-1; one_hot below uses them to address one entry per sample\n", | |
| "np.arange(len(y_train))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "id": "66d28776", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "y_train.shape=(69000,)\n", | |
| "y_train_one_hot.shape=(10, 69000)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def one_hot(y, y_max=None):\n", | |
| "    \"\"\"One-hot encode integer labels y of shape (m,) into a (y_max + 1, m) matrix.\n", | |
| "\n", | |
| "    Column j has a single 1 in row y[j]. When y_max is None it is inferred\n", | |
| "    from y.max(); pass it explicitly if a batch may miss the largest class,\n", | |
| "    otherwise the result has fewer rows than the network output.\n", | |
| "    \"\"\"\n", | |
| "    n_classes = (y.max() if y_max is None else y_max) + 1\n", | |
| "    encoded = np.zeros((y.shape[0], n_classes))\n", | |
| "    encoded[np.arange(y.shape[0]), y] = 1\n", | |
| "    return encoded.T\n", | |
| "\n", | |
| "y_train_one_hot = one_hot(y_train)\n", | |
| "print(f\"{y_train.shape=}\\n{y_train_one_hot.shape=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "id": "16f692f6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[5 3 9]\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0., 0., 0.],\n", | |
| " [0., 0., 0.],\n", | |
| " [0., 0., 0.],\n", | |
| " [0., 1., 0.],\n", | |
| " [0., 0., 0.],\n", | |
| " [1., 0., 0.],\n", | |
| " [0., 0., 0.],\n", | |
| " [0., 0., 0.],\n", | |
| " [0., 0., 0.],\n", | |
| " [0., 0., 1.]])" | |
| ] | |
| }, | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Each column below is the one-hot encoding of the corresponding printed label\n", | |
| "samples = slice(5, 8)\n", | |
| "print(y_train[samples])\n", | |
| "y_train_one_hot[:, samples]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "fa715e8c", | |
| "metadata": {}, | |
| "source": [ | |
| "# 6. Loss Function" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "cd79d70d", | |
| "metadata": {}, | |
| "source": [ | |
| "A função de custo da classificação multiclasse costuma ser a Cross Entropy, às vezes chamada de \n", | |
| "Categorical Cross Entropy." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ff7beba3", | |
| "metadata": {}, | |
| "source": [ | |
| "" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "id": "3082220d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Z1.shape=(12, 5)\n", | |
| "A1.shape=(12, 5)\n", | |
| "Z2.shape=(10, 5)\n", | |
| "A2.shape=(10, 5)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.075, 0.123, 0.106, 0.084, 0.121],\n", | |
| " [0.116, 0.103, 0.09 , 0.101, 0.105],\n", | |
| " [0.12 , 0.095, 0.11 , 0.125, 0.097],\n", | |
| " [0.083, 0.064, 0.088, 0.104, 0.071],\n", | |
| " [0.143, 0.133, 0.139, 0.127, 0.16 ],\n", | |
| " [0.071, 0.11 , 0.081, 0.066, 0.063],\n", | |
| " [0.103, 0.096, 0.088, 0.094, 0.077],\n", | |
| " [0.119, 0.098, 0.106, 0.142, 0.136],\n", | |
| " [0.083, 0.09 , 0.086, 0.06 , 0.093],\n", | |
| " [0.085, 0.087, 0.107, 0.097, 0.077]])" | |
| ] | |
| }, | |
| "execution_count": 39, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "first_five = X_train_scaled[:, 0:5]  # five samples, one per column\n", | |
| "Z1, A1, Z2, A2 = forward(first_five, W1, b1, W2, b2, verbose=True)\n", | |
| "np.round(A2, 3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "fc761701", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Cross-entropy loss\n", | |
| "# epsilon is a small constant that avoids log(0)\n", | |
| "def cross_entropy_loss(A2, Y, y_max=None, epsilon=1e-15):\n", | |
| "    \"\"\"Mean categorical cross-entropy between predicted probabilities and labels.\n", | |
| "\n", | |
| "    A2: (n_classes, m) probabilities, one column per sample (softmax output).\n", | |
| "    Y: (m,) integer class labels.\n", | |
| "    y_max: largest class index. Defaults to A2.shape[0] - 1 so the one-hot\n", | |
| "        matrix always has as many rows as A2; inferring it from Y.max() fails\n", | |
| "        to broadcast on batches that do not contain the highest class.\n", | |
| "    epsilon: added inside the log so a zero probability does not give -inf.\n", | |
| "\n", | |
| "    Returns the loss averaged over the m samples.\n", | |
| "    \"\"\"\n", | |
| "    if y_max is None:\n", | |
| "        y_max = A2.shape[0] - 1\n", | |
| "    m = Y.shape[0]\n", | |
| "    one_hot_Y = one_hot(Y, y_max)\n", | |
| "    # The mask keeps only the log-probability of the true class in each column\n", | |
| "    logprobs = np.log(A2 + epsilon) * one_hot_Y\n", | |
| "    return -np.sum(logprobs) / m" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "id": "7c1c8765", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "2.314" | |
| ] | |
| }, | |
| "execution_count": 41, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "loss_first_five = cross_entropy_loss(A2, y_train[:5], y_max=9)\n", | |
| "loss_first_five.round(4).item()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "56e5ed10", | |
| "metadata": {}, | |
| "source": [ | |
| "# 7. Backward pass" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0fb18890", | |
| "metadata": {}, | |
| "source": [ | |
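| "Aqui escrevemos as derivadas usadas na retropropagação (backward pass): a derivada da ReLU e os gradientes da perda em relação a cada parâmetro." | |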
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "4a811c61", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def relu_derivative(z):\n", | |
| "    \"\"\"Derivative of ReLU as ints: 1 where z > 0, 0 elsewhere (including z == 0).\"\"\"\n", | |
| "    return np.asarray(z > 0).astype(int)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "aed088dc", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Forward computation recap; backward() below applies the chain rule to it:\n", | |
| "# Z1 = W1 @ X + b1\n", | |
| "# A1 = relu(Z1)\n", | |
| "# Z2 = W2 @ A1 + b2\n", | |
| "# A2 = softmax(Z2)\n", | |
| "# loss = cross_entropy_loss(A2, Y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "id": "86e6f662", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def backward(Z1, A1, Z2, A2, W1, W2, X, Y, verbose=False):\n", | |
| "    \"\"\"Backpropagate the mean cross-entropy loss through the 2-layer network.\n", | |
| "\n", | |
| "    Z1, A1, A2: cached values from forward(); X: input batch (one sample per\n", | |
| "    column); Y: (m,) integer labels; W2: output-layer weights.\n", | |
| "    Z2 and W1 are not needed for the gradients; they are accepted so existing\n", | |
| "    calls keep working. The hidden-layer gradient uses relu_derivative, so this\n", | |
| "    assumes ACTIVATION_FUNCTION is relu.\n", | |
| "\n", | |
| "    Returns (dW1, db1, dW2, db2), each shaped like its parameter.\n", | |
| "    \"\"\"\n", | |
| "    m = Y.shape[0]\n", | |
| "    # Size the one-hot matrix from the network output, not from Y.max(): a batch\n", | |
| "    # that lacks the highest class would otherwise give fewer rows than A2.\n", | |
| "    one_hot_Y = one_hot(Y, A2.shape[0] - 1)\n", | |
| "\n", | |
| "    # 1. Output layer: for softmax + cross-entropy, dL/dZ2 simplifies to A2 - Y_one_hot\n", | |
| "    dZ2 = A2 - one_hot_Y\n", | |
| "    dW2 = dZ2 @ A1.T / m\n", | |
| "    db2 = np.sum(dZ2, axis=1, keepdims=True) / m\n", | |
| "\n", | |
| "    # 2. Hidden layer: propagate back through W2 and gate by the ReLU derivative\n", | |
| "    dZ1 = (W2.T @ dZ2) * relu_derivative(Z1)\n", | |
| "    dW1 = dZ1 @ X.T / m\n", | |
| "    db1 = np.sum(dZ1, axis=1, keepdims=True) / m\n", | |
| "\n", | |
| "    if verbose:\n", | |
| "        print(f\"{dW1.shape=}\\n{db1.shape=}\\n{dW2.shape=}\\n{db2.shape=}\")\n", | |
| "    return dW1, db1, dW2, db2\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "0aa97361", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "dW1.shape=(12, 784)\n", | |
| "db1.shape=(12, 1)\n", | |
| "dW2.shape=(10, 12)\n", | |
| "db2.shape=(10, 1)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "batch_X, batch_y = X_train_scaled[:, 0:5], y_train[0:5]\n", | |
| "_ = backward(Z1, A1, Z2, A2, W1, W2, batch_X, batch_y, verbose=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "cb6b46be", | |
| "metadata": {}, | |
| "source": [ | |
| "# 8. Atualização dos parâmetros treináveis" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "701a2f02", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n", | |
| "    \"\"\"Take one gradient-descent step: each parameter minus alpha times its gradient.\n", | |
| "\n", | |
| "    New arrays are returned; the arrays passed in are left unchanged.\n", | |
| "    \"\"\"\n", | |
| "    params = (W1, b1, W2, b2)\n", | |
| "    grads = (dW1, db1, dW2, db2)\n", | |
| "    W1, b1, W2, b2 = (param - grad * alpha for param, grad in zip(params, grads))\n", | |
| "    return W1, b1, W2, b2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "6cf7991d", | |
| "metadata": {}, | |
| "source": [ | |
| "# 9. Gradient Descent\n", | |
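| "(aqui usamos a versão mais básica do algoritmo: gradiente descendente em lote completo, com todo o conjunto de treino a cada iteração e taxa de aprendizado fixa)" | |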
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 47, | |
| "id": "0ae7c1bc", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_predictions(A2):\n", | |
| "    \"\"\"Predicted class of each sample: row index of the largest value in each column.\"\"\"\n", | |
| "    return np.argmax(A2, axis=0)\n", | |
| "\n", | |
| "def get_accuracy(predictions, Y):\n", | |
| "    \"\"\"Fraction of predictions equal to the true labels Y.\"\"\"\n", | |
| "    n_correct = np.sum(predictions == Y)\n", | |
| "    return n_correct / Y.size" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "id": "4f199e35", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def make_predictions(X, W1, b1, W2, b2):\n", | |
| "    \"\"\"Run forward() on X (one sample per column) and return the predicted class per column.\"\"\"\n", | |
| "    *_, A2 = forward(X, W1, b1, W2, b2)\n", | |
| "    return get_predictions(A2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "id": "98da370d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def gradient_descent(X_train, y_train, X_test, y_test, n_iter, alpha, params, log_every=10):\n", | |
| "    \"\"\"Train with full-batch gradient descent and periodically report metrics.\n", | |
| "\n", | |
| "    X_train, X_test: inputs with one sample per column.\n", | |
| "    y_train, y_test: integer labels of shape (m,).\n", | |
| "    n_iter: number of parameter updates; alpha: learning rate.\n", | |
| "    params: initial (W1, b1, W2, b2).\n", | |
| "    log_every: print train/test loss and accuracy every `log_every` iterations\n", | |
| "        (default 10); 0 or None disables the report.\n", | |
| "\n", | |
| "    Returns the trained (W1, b1, W2, b2).\n", | |
| "    \"\"\"\n", | |
| "    W1, b1, W2, b2 = params\n", | |
| "    # Largest class index (one output row per class); passed to the loss so its\n", | |
| "    # one-hot matrix always matches A2 even if a label set lacks the top class.\n", | |
| "    y_max = W2.shape[0] - 1\n", | |
| "    for i in range(n_iter):\n", | |
| "        Z1, A1, Z2, A2 = forward(X_train, W1, b1, W2, b2)\n", | |
| "        if log_every and i % log_every == 0:\n", | |
| "            # Train and test metrics both use this iteration's parameters,\n", | |
| "            # i.e. before the update below.\n", | |
| "            train_acc = get_accuracy(get_predictions(A2), y_train).round(3).item()\n", | |
| "            train_loss = cross_entropy_loss(A2, y_train, y_max=y_max).round(3).item()\n", | |
| "\n", | |
| "            # The loss needs predicted probabilities, not the class indices that\n", | |
| "            # make_predictions returns (passing those made the test loss meaningless).\n", | |
| "            _, _, _, A2_test = forward(X_test, W1, b1, W2, b2)\n", | |
| "            test_acc = get_accuracy(get_predictions(A2_test), y_test).round(3).item()\n", | |
| "            test_loss = cross_entropy_loss(A2_test, y_test, y_max=y_max).round(3).item()\n", | |
| "\n", | |
| "            print(f\"It {i:>4} - TRAIN loss: {str(train_loss):<5} acc: {str(train_acc):<5} - TEST loss: {str(test_loss):<5} acc: {str(test_acc):<5}\")\n", | |
| "        dW1, db1, dW2, db2 = backward(Z1, A1, Z2, A2, W1, W2, X_train, y_train)\n", | |
| "        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n", | |
| "    return W1, b1, W2, b2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "df165e55", | |
| "metadata": {}, | |
| "source": [ | |
| "# 10. Treinamento" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "c5d3727a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "W1.shape=(12, 784)\n", | |
| "b1.shape=(12, 1)\n", | |
| "W2.shape=(10, 12)\n", | |
| "b2.shape=(10, 1)\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "It 0 - TRAIN loss: 2.327 acc: 0.122 - TEST loss: 2.376 acc: 0.136\n", | |
| "It 10 - TRAIN loss: 2.212 acc: 0.251 - TEST loss: 5.614 acc: 0.271\n", | |
| "It 20 - TRAIN loss: 2.081 acc: 0.307 - TEST loss: 8.208 acc: 0.305\n", | |
| "It 30 - TRAIN loss: 1.921 acc: 0.345 - TEST loss: 7.35 acc: 0.344\n", | |
| "It 40 - TRAIN loss: 1.744 acc: 0.431 - TEST loss: 5.181 acc: 0.442\n", | |
| "It 50 - TRAIN loss: 1.557 acc: 0.525 - TEST loss: 4.0 acc: 0.535\n", | |
| "It 60 - TRAIN loss: 1.371 acc: 0.6 - TEST loss: 3.319 acc: 0.602\n", | |
| "It 70 - TRAIN loss: 1.207 acc: 0.655 - TEST loss: 2.876 acc: 0.651\n", | |
| "It 80 - TRAIN loss: 1.075 acc: 0.697 - TEST loss: 2.571 acc: 0.689\n", | |
| "It 90 - TRAIN loss: 0.971 acc: 0.727 - TEST loss: 2.454 acc: 0.73 \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "N_ITER = 100\n", | |
| "LEARNING_RATE = 0.1\n", | |
| "# Alternative initialisations: init_params_uniform(low=-0.5, high=0.5) or init_params_glorot()\n", | |
| "params = init_params_standard_normal(sigma=0.1, mu=0)\n", | |
| "W1, b1, W2, b2 = gradient_descent(X_train_scaled, y_train, X_test_scaled, y_test, N_ITER, LEARNING_RATE, params)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "id": "b498f13d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.75" | |
| ] | |
| }, | |
| "execution_count": 51, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "test_predictions = make_predictions(X_test_scaled, W1, b1, W2, b2)\n", | |
| "get_accuracy(test_predictions, y_test).item()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "id": "7e57815a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Prediction: [8]\n", | |
| "Label: 3\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGcNJREFUeJzt3XuMFdXhB/CzqCyo7NIV2YeAgqg0IjS1iESlWAmIjRU1jViTYmM0UDQV6iPbqGhbs61tWrWl2j8aqa1vUyASS6KrQGpBI5RS+6AuQYHCYjVll0dBuswvM7/fblkB/d3LLufuvZ9PMrnce+fsHM6ene89M+fOlCVJkgQAOMp6He0NAoAAAiAaIyAAohBAAEQhgACIQgABEIUAAiAKAQRAFMeGArN///6wZcuW0K9fv1BWVha7OgDkKL2+wY4dO0JdXV3o1atXzwmgNHwGDx4cuxoAHKFNmzaFQYMG9ZxDcOnIB4Ce75P2590WQPPmzQunnXZa6NOnTxg7dmx44403/l/lHHYDKA6ftD/vlgB65plnwpw5c8LcuXPD6tWrw+jRo8PkyZPDe++91x2bA6AnSrrBeeedl8yaNavjeVtbW1JXV5c0NDR8YtmWlpb06twWbaAP6AP6QOjZbZDuzz9Ol4+APvzww7Bq1aowceLEjtfSWRDp8xUrVhy0/t69e0Nra2unBYDi1+UB9P7774e2trZQXV3d6fX0eXNz80HrNzQ0hMrKyo7FDDiA0hB9Flx9fX1oaWnpWNJpewAUvy7/HtCAAQPCMcccE7Zt29bp9fR5TU3NQeuXl5dnCwClpctHQL179w7nnntuaGxs7HR1g/T5uHHjunpzAPRQ3XIlhHQK9vTp08PnPve5cN5554UHH3ww7Nq1K3zta1/rjs0B0AN1SwBdc8014Z///Ge45557sokHn/nMZ8KSJUsOmpgAQOkqS+dihwKSTsNOZ8MB0LOlE8sqKioKdxYcAKVJAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBIAAAqB0GAEBEIUAAiCKY+NsFjgaxo8fn1e5nTt35lxm7NixOZeZOXNmzmXOPvvsnMtcccUVIR+LFy/Oqxz/P0ZAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiCKsiRJklBAWltbQ2VlZexqQMH58pe/nHOZefPm5bWtffv25Vymuro6FKrt27fnVW7GjBk5l3n++efz2lYxamlpCRUVFYd93wgIgCgEEADFEUD33ntvKCsr67SMGDGiqzcDQA/XLTekS28Y9fLLL/93I8e67x0AnXVLMqSBU1NT0x0/GoAi0S3ngN5+++1QV1cXhg0bFq677rqwcePGw667d+/ebObbgQsAxa/LAyi9L/z8+fPDkiVLwiOPPBI2bNgQLrroorBjx45Drt/Q0JBNu25fBg8e3NVVAqAUAmjKlCnZ9xVGjRoVJk+eHF588cVsDv6zzz57yPXr6+uzueLty6ZNm7q6SgAUoG6fHdC/f/9w5plnhqampkO+X15eni0AlJZu/x7Qzp07w/r160NtbW13bwqAUg6g2267LSxbtiy888474fe//3248sorwzHHHBOuvfbart4UAD1Ylx+C27x5cxY2H3zwQTj55JPDhRdeGFauXJn9GwDauRgp9BBr1qzJuczIkSO7pS6lIp+LmF5zzTU5l2lsbAzFyMVIAShILkYKQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAxXlDOoCeKr2hZq6GDx+ec5nGIr0Y6ScxAgIgCgEE
QBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKJwNWzoIWbNmpVzmSeffDKvbZ1yyil5lSs2ra2tOZfZuHFjt9SlGBkBARCFAAJAAAFQOoyAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCxUihh3jttddyLrN8+fK8tnXttdfmVa7YvPvuuzmX+e1vf9stdSlGRkAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoXI6UoTZ06Na9yCxcuDEfD+eefn3OZL33pSzmXGT16dM5l+K9FixZpjm5kBARAFAIIgJ4RQOn9RS6//PJQV1cXysrKDjpkkSRJuOeee0JtbW3o27dvmDhxYnj77be7ss4AlGIA7dq1KzuuPG/evEO+/8ADD4SHH344PProo+H1118PJ5xwQpg8eXLYs2dPV9QXgFKdhDBlypRsOZR09PPggw+Gu+66K1xxxRXZa48//niorq7ORkrTpk078hoDUBS69BzQhg0bQnNzc3bYrV1lZWUYO3ZsWLFixSHL7N27N7S2tnZaACh+XRpAafik0hHPgdLn7e99VENDQxZS7cvgwYO7skoAFKjos+Dq6+tDS0tLx7Jp06bYVQKgpwVQTU1N9rht27ZOr6fP29/7qPLy8lBRUdFpAaD4dWkADR06NAuaxsbGjtfSczrpbLhx48Z15aYAKLVZcDt37gxNTU2dJh6sWbMmVFVVhSFDhoRbb701fPe73w1nnHFGFkh333139p2hfC+NAkBxyjmA3nzzzXDxxRd3PJ8zZ072OH369DB//vxwxx13ZN8Vuummm8L27dvDhRdeGJYsWRL69OnTtTUHoEcrS9Iv7xSQ9JBdOhuO4jRhwoScy8yYMeOoXLgzdbS+MJ2e+zwaZfhfd955Z15N8dBDD+Vc5j//+Y9m/z/pxLKPO68ffRYcAKVJAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAnnE7BjgS7bfvyMVll1121Bq9d+/eR2U7ZWVlOZcpsAvXd4nVq1fnXGb27Nk5l0lvipkPV7buXkZAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKFyMFopk7d27OZV577bVuqQtHnxEQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCxUjJ28iRI3Muc8455+RcpqysLBSbXr1y/+y3f//+bqkLxGIEBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACicDFS8vbWW2/lXOaZZ57Jucy0adNyLtO7d++Qj7///e/haMjnAqtJkuRc5qyzzgr5OPnkk3Mu86c//SnnMn/+859zLkPxMAICIAoBBEDPCKDly5eHyy+/PNTV1WWHERYuXNjp/euvvz57/cDl0ksv7co6A1CKAbRr164wevToMG/evMOukwbO1q1bO5annnrqSOsJQKlPQpgyZUq2fJzy8vJQU1NzJPUCoMh1yzmgpUuXhoEDB2YzcGbOnBk++OCDw667d+/e0Nra2mkBoPh1eQClh98ef/zx0NjYGL7//e+HZcuWZSOmtra2Q67f0NAQKisrO5bBgwd3dZUAKIXvAR34nY1zzjknjBo1Kpx++unZqOiSSy45aP36+vowZ86cjufpCEgIARS/bp+GPWzYsDBgwIDQ1NR02PNFFRUVnRYAil+3B9DmzZuzc0C1tbXdvSkAivkQ3M6dOzuNZjZs2BDWrFkTqqqqsuW+++4LV199dTYLbv369eGOO+4Iw4cPD5MnT+7qugNQSgH05ptvhosvvrjjefv5m+nTp4dHHnkkrF27Nvzyl78M27dvz76sOmnSpPCd73wnO9QG
AO3KknyucNiN0kkI6Ww4aJd+8TlXffv2zasBV65cWbANf+yxuc8ZeuKJJ/LaVnoU42jIp37ph116hpaWlo89r+9acABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQHHckhu62h//+EeNGkI4//zzC/aq1vl6+umnY1eBiIyAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAULkZ6lLzzzjs5l7n//vtzLvPiiy/mXOYf//hHKGQjRozIuUyfPn3y2taePXtyLjN9+vScy5SVleVc5qtf/WooZA899FDOZZYuXdotdaFnMAICIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFGUJUmShALS2toaKisrQ7Fpa2vLuUw+v5rVq1fnXGbdunUhHz/84Q9zLvP+++/nXGbhwoU5l6murg75WL9+fc5lLrrooqNyMdJ8+sP+/ftDPjZv3pxzmQsuuCDnMlu3bs25DD1HS0tLqKioOOz7RkAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIIpj42yW7nL22WfnXGbHjh15bWvVqlU5l3n++edzLrNnz56cy9TV1eVc5kjKFapf/epXeZW74YYburwu8FFGQABEIYAAKPwAamhoCGPGjAn9+vULAwcODFOnTj3oXjLp4ZJZs2aFk046KZx44onh6quvDtu2bevqegNQSgG0bNmyLFxWrlwZXnrppbBv374wadKksGvXro51Zs+eHV544YXw3HPPZetv2bIlXHXVVd1RdwBKZRLCkiVLOj2fP39+NhJKT0aPHz8+u/vdL37xi/Dkk0+GL3zhC9k6jz32WPj0pz+dhdb555/ftbUHoDTPAaWBk6qqqsoe0yBKR0UTJ07sWGfEiBFhyJAhYcWKFYf8GXv37s1uw33gAkDxyzuA0nvN33rrrdl94EeOHJm91tzcHHr37h369+/fad3q6ursvcOdV6qsrOxYBg8enG+VACiFAErPBb311lvh6aefPqIK1NfXZyOp9mXTpk1H9PMAKOIvot58881h8eLFYfny5WHQoEEdr9fU1IQPP/wwbN++vdMoKJ0Fl753KOXl5dkCQGnJaQSUJEkWPgsWLAivvPJKGDp0aKf3zz333HDccceFxsbGjtfSadobN24M48aN67paA1BaI6D0sFs6w23RokXZd4Haz+uk52769u2bPaaX8JgzZ042MaGioiLccsstWfiYAQdA3gH0yCOPZI8TJkzo9Ho61fr666/P/v3jH/849OrVK/sCajrDbfLkyeFnP/tZLpsBoASUJelxtQKSTsNOR1LFpq2tLecyBfaroQs1NTXlXCafCT/3339/yEf6dQo4UunEsvRI2OG4FhwAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAAdBz7ohK7rZs2ZJzmdraWk19lK1fvz7nMrt37865zPPPP3/UrmwNhcoICIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABE4WKkR8mll16ac5mf/vSnOZfp06dPzmXGjBkT8nHXXXflXOZf//pXOBqmTZuWV7lrr7025zLNzc15bQtKnREQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIiiLEmSJBSQ1tbWUFlZGbsaAByhlpaWUFFRcdj3jYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAAo/gBoaGsKYMWNCv379wsCBA8PUqVPDunXrOq0zYcKEUFZW1mmZMWNGV9cbgFIKoGXLloVZs2aF
lStXhpdeeins27cvTJo0KezatavTejfeeGPYunVrx/LAAw90db0B6OGOzWXlJUuWdHo+f/78bCS0atWqMH78+I7Xjz/++FBTU9N1tQSg6PQ60tutpqqqqjq9/sQTT4QBAwaEkSNHhvr6+rB79+7D/oy9e/dmt+E+cAGgBCR5amtrS774xS8mF1xwQafXf/7znydLlixJ1q5dm/z6179OTjnllOTKK6887M+ZO3duklbDog30AX1AHwhF1QYtLS0fmyN5B9CMGTOSU089Ndm0adPHrtfY2JhVpKmp6ZDv79mzJ6tk+5L+vNiNZtEG+oA+oA+Ebg+gnM4Btbv55pvD4sWLw/Lly8OgQYM+dt2xY8dmj01NTeH0008/6P3y8vJsAaC05BRA6YjplltuCQsWLAhLly4NQ4cO/cQya9asyR5ra2vzryUApR1A6RTsJ598MixatCj7LlBzc3P2emVlZejbt29Yv3599v5ll10WTjrppLB27dowe/bsbIbcqFGjuuv/AEBPlMt5n8Md53vsscey9zdu3JiMHz8+qaqqSsrLy5Phw4cnt99++yceBzxQuq5jr46/6wP6gD4QenwbfNK+v+z/gqVgpNOw0xEVAD1b+lWdioqKw77vWnAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARFFwAZQkSewqAHAU9ucFF0A7duyIXQUAjsL+vCwpsCHH/v37w5YtW0K/fv1CWVlZp/daW1vD4MGDw6ZNm0JFRUUoVdpBO+gP/i4Kef+QxkoaPnV1daFXr8OPc44NBSat7KBBgz52nbRRSzmA2mkH7aA/+Lso1P1DZWXlJ65TcIfgACgNAgiAKHpUAJWXl4e5c+dmj6VMO2gH/cHfRTHsHwpuEgIApaFHjYAAKB4CCIAoBBAAUQggAKLoMQE0b968cNppp4U+ffqEsWPHhjfeeCOUmnvvvTe7OsSBy4gRI0KxW758ebj88suzb1Wn/+eFCxd2ej+dR3PPPfeE2tra0Ldv3zBx4sTw9ttvh1Jrh+uvv/6g/nHppZeGYtLQ0BDGjBmTXSll4MCBYerUqWHdunWd1tmzZ0+YNWtWOOmkk8KJJ54Yrr766rBt27ZQau0wYcKEg/rDjBkzQiHpEQH0zDPPhDlz5mRTC1evXh1Gjx4dJk+eHN57771Qas4+++ywdevWjuV3v/tdKHa7du3Kfufph5BDeeCBB8LDDz8cHn300fD666+HE044Iesf6Y6olNohlQbOgf3jqaeeCsVk2bJlWbisXLkyvPTSS2Hfvn1h0qRJWdu0mz17dnjhhRfCc889l62fXtrrqquuCqXWDqkbb7yxU39I/1YKStIDnHfeecmsWbM6nre1tSV1dXVJQ0NDUkrmzp2bjB49OillaZddsGBBx/P9+/cnNTU1yQ9+8IOO17Zv356Ul5cnTz31VFIq7ZCaPn16csUVVySl5L333svaYtmyZR2/++OOOy557rnnOtb561//mq2zYsWKpFTaIfX5z38++cY3vpEUsoIfAX344Ydh1apV2WGVA68Xlz5fsWJFKDXpoaX0EMywYcPCddddFzZu3BhK2YYNG0Jzc3On/pFegyo9TFuK/WPp0qXZIZmzzjorzJw5M3zwwQehmLW0tGSPVVVV2WO6r0hHAwf2h/Qw9ZAhQ4q6P7R8pB3aPfHEE2HAgAFh5MiRob6+PuzevTsUkoK7GOlHvf/++6GtrS1UV1d3ej19/re//S2UknSnOn/+/Gznkg6n77vvvnDRRReFt956KzsWXIrS8Ekdqn+0v1cq0sNv6aGmoUOHhvXr14dvfetbYcqUKdmO95hjjgnFJr1y/q233houuOCCbAebSn/nvXv3Dv379y+Z/rD/EO2Q+spXvhJOPfXU7APr2rVrw513
3pmdJ/rNb34TCkXBBxD/le5M2o0aNSoLpLSDPfvss+GGG27QVCVu2rRpHf8+55xzsj5y+umnZ6OiSy65JBSb9BxI+uGrFM6D5tMON910U6f+kE7SSftB+uEk7ReFoOAPwaXDx/TT20dnsaTPa2pqQilLP+WdeeaZoampKZSq9j6gfxwsPUyb/v0UY/+4+eabw+LFi8Orr77a6fYtaX9ID9tv3769JPYXNx+mHQ4l/cCaKqT+UPABlA6nzz333NDY2NhpyJk+HzduXChlO3fuzD7NpJ9sSlV6uCndsRzYP9IbcqWz4Uq9f2zevDk7B1RM/SOdf5HudBcsWBBeeeWV7Pd/oHRfcdxxx3XqD+lhp/RcaTH1h+QT2uFQ1qxZkz0WVH9IeoCnn346m9U0f/785C9/+Uty0003Jf3790+am5uTUvLNb34zWbp0abJhw4bktddeSyZOnJgMGDAgmwFTzHbs2JH84Q9/yJa0y/7oRz/K/v3uu+9m73/ve9/L+sOiRYuStWvXZjPBhg4dmvz73/9OSqUd0vduu+22bKZX2j9efvnl5LOf/WxyxhlnJHv27EmKxcyZM5PKysrs72Dr1q0dy+7duzvWmTFjRjJkyJDklVdeSd58881k3Lhx2VJMZn5COzQ1NSXf/va3s/9/2h/Sv41hw4Yl48ePTwpJjwig1E9+8pOsU/Xu3Tublr1y5cqk1FxzzTVJbW1t1gannHJK9jztaMXu1VdfzXa4H13SacftU7HvvvvupLq6OvugcskllyTr1q1LSqkd0h3PpEmTkpNPPjmbhnzqqacmN954Y9F9SDvU/z9dHnvssY510g8eX//615NPfepTyfHHH59ceeWV2c65lNph48aNWdhUVVVlfxPDhw9Pbr/99qSlpSUpJG7HAEAUBX8OCIDiJIAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgxPA/j5nHnP7r1T8AAAAASUVORK5CYII=", | |
| "text/plain": [ | |
| "<Figure size 640x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "def test_prediction(index, W1, b1, W2, b2, X_train):\n", | |
| " current_image = X_train[:, index, None]\n", | |
| " prediction = make_predictions(X_train[:, index, None], W1, b1, W2, b2)\n", | |
| " label = y_train[index]\n", | |
| " print(\"Prediction: \", prediction)\n", | |
| " print(\"Label: \", label)\n", | |
| " \n", | |
| " current_image = current_image.reshape((28, 28)) * 255\n", | |
| " plt.gray()\n", | |
| " plt.imshow(current_image, interpolation='nearest')\n", | |
| " plt.show()\n", | |
| "\n", | |
| "i = np.random.randint(0, X_train_scaled.shape[1])\n", | |
| "test_prediction(i, W1, b1, W2, b2, X_train_scaled)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "55d784fb", | |
| "metadata": {}, | |
| "source": [ | |
| "# Extra: sklearn MLPClassifier" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "8a6e516a", | |
| "metadata": {}, | |
| "source": [ | |
| "https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 53, | |
| "id": "aa6cfd27", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iteration 1, loss = 2.35732861\n", | |
| "Iteration 2, loss = 2.30002841\n", | |
| "Iteration 3, loss = 2.26468002\n", | |
| "Iteration 4, loss = 2.23664851\n", | |
| "Iteration 5, loss = 2.21013676\n", | |
| "Iteration 6, loss = 2.18399348\n", | |
| "Iteration 7, loss = 2.15824992\n", | |
| "Iteration 8, loss = 2.13285909\n", | |
| "Iteration 9, loss = 2.10772348\n", | |
| "Iteration 10, loss = 2.08276116\n", | |
| "Iteration 11, loss = 2.05785163\n", | |
| "Iteration 12, loss = 2.03291732\n", | |
| "Iteration 13, loss = 2.00782073\n", | |
| "Iteration 14, loss = 1.98242809\n", | |
| "Iteration 15, loss = 1.95666368\n", | |
| "Iteration 16, loss = 1.93047865\n", | |
| "Iteration 17, loss = 1.90384422\n", | |
| "Iteration 18, loss = 1.87670576\n", | |
| "Iteration 19, loss = 1.84903195\n", | |
| "Iteration 20, loss = 1.82080313\n", | |
| "Iteration 21, loss = 1.79196680\n", | |
| "Iteration 22, loss = 1.76255879\n", | |
| "Iteration 23, loss = 1.73259370\n", | |
| "Iteration 24, loss = 1.70214221\n", | |
| "Iteration 25, loss = 1.67141160\n", | |
| "Iteration 26, loss = 1.64054857\n", | |
| "Iteration 27, loss = 1.60967903\n", | |
| "Iteration 28, loss = 1.57901073\n", | |
| "Iteration 29, loss = 1.54863661\n", | |
| "Iteration 30, loss = 1.51861898\n", | |
| "Iteration 31, loss = 1.48902494\n", | |
| "Iteration 32, loss = 1.45991991\n", | |
| "Iteration 33, loss = 1.43139716\n", | |
| "Iteration 34, loss = 1.40352680\n", | |
| "Iteration 35, loss = 1.37630772\n", | |
| "Iteration 36, loss = 1.34976963\n", | |
| "Iteration 37, loss = 1.32394366\n", | |
| "Iteration 38, loss = 1.29879471\n", | |
| "Iteration 39, loss = 1.27433017\n", | |
| "Iteration 40, loss = 1.25054599\n", | |
| "Iteration 41, loss = 1.22741673\n", | |
| "Iteration 42, loss = 1.20491746\n", | |
| "Iteration 43, loss = 1.18304326\n", | |
| "Iteration 44, loss = 1.16177506\n", | |
| "Iteration 45, loss = 1.14109908\n", | |
| "Iteration 46, loss = 1.12100476\n", | |
| "Iteration 47, loss = 1.10147927\n", | |
| "Iteration 48, loss = 1.08251611\n", | |
| "Iteration 49, loss = 1.06409965\n", | |
| "Iteration 50, loss = 1.04621780\n", | |
| "Iteration 51, loss = 1.02886320\n", | |
| "Iteration 52, loss = 1.01202833\n", | |
| "Iteration 53, loss = 0.99570439\n", | |
| "Iteration 54, loss = 0.97988206\n", | |
| "Iteration 55, loss = 0.96455341\n", | |
| "Iteration 56, loss = 0.94970336\n", | |
| "Iteration 57, loss = 0.93532258\n", | |
| "Iteration 58, loss = 0.92140143\n", | |
| "Iteration 59, loss = 0.90793103\n", | |
| "Iteration 60, loss = 0.89489579\n", | |
| "Iteration 61, loss = 0.88228555\n", | |
| "Iteration 62, loss = 0.87008118\n", | |
| "Iteration 63, loss = 0.85827249\n", | |
| "Iteration 64, loss = 0.84684464\n", | |
| "Iteration 65, loss = 0.83578889\n", | |
| "Iteration 66, loss = 0.82508825\n", | |
| "Iteration 67, loss = 0.81473054\n", | |
| "Iteration 68, loss = 0.80470465\n", | |
| "Iteration 69, loss = 0.79499843\n", | |
| "Iteration 70, loss = 0.78559988\n", | |
| "Iteration 71, loss = 0.77649974\n", | |
| "Iteration 72, loss = 0.76768648\n", | |
| "Iteration 73, loss = 0.75914704\n", | |
| "Iteration 74, loss = 0.75087207\n", | |
| "Iteration 75, loss = 0.74285110\n", | |
| "Iteration 76, loss = 0.73507294\n", | |
| "Iteration 77, loss = 0.72752922\n", | |
| "Iteration 78, loss = 0.72021077\n", | |
| "Iteration 79, loss = 0.71310959\n", | |
| "Iteration 80, loss = 0.70621668\n", | |
| "Iteration 81, loss = 0.69952432\n", | |
| "Iteration 82, loss = 0.69302505\n", | |
| "Iteration 83, loss = 0.68671011\n", | |
| "Iteration 84, loss = 0.68057239\n", | |
| "Iteration 85, loss = 0.67460546\n", | |
| "Iteration 86, loss = 0.66880347\n", | |
| "Iteration 87, loss = 0.66316041\n", | |
| "Iteration 88, loss = 0.65766937\n", | |
| "Iteration 89, loss = 0.65232467\n", | |
| "Iteration 90, loss = 0.64712191\n", | |
| "Iteration 91, loss = 0.64205608\n", | |
| "Iteration 92, loss = 0.63712215\n", | |
| "Iteration 93, loss = 0.63231484\n", | |
| "Iteration 94, loss = 0.62762900\n", | |
| "Iteration 95, loss = 0.62306156\n", | |
| "Iteration 96, loss = 0.61860839\n", | |
| "Iteration 97, loss = 0.61426516\n", | |
| "Iteration 98, loss = 0.61002817\n", | |
| "Iteration 99, loss = 0.60589379\n", | |
| "Iteration 100, loss = 0.60185840\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "c:\\Users\\Bruno\\miniconda3\\envs\\cat\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:785: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.\n", | |
| " warnings.warn(\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style>#sk-container-id-1 {\n", | |
| " /* Definition of color scheme common for light and dark mode */\n", | |
| " --sklearn-color-text: #000;\n", | |
| " --sklearn-color-text-muted: #666;\n", | |
| " --sklearn-color-line: gray;\n", | |
| " /* Definition of color scheme for unfitted estimators */\n", | |
| " --sklearn-color-unfitted-level-0: #fff5e6;\n", | |
| " --sklearn-color-unfitted-level-1: #f6e4d2;\n", | |
| " --sklearn-color-unfitted-level-2: #ffe0b3;\n", | |
| " --sklearn-color-unfitted-level-3: chocolate;\n", | |
| " /* Definition of color scheme for fitted estimators */\n", | |
| " --sklearn-color-fitted-level-0: #f0f8ff;\n", | |
| " --sklearn-color-fitted-level-1: #d4ebff;\n", | |
| " --sklearn-color-fitted-level-2: #b3dbfd;\n", | |
| " --sklearn-color-fitted-level-3: cornflowerblue;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1.light {\n", | |
| " /* Specific color for light theme */\n", | |
| " --sklearn-color-text-on-default-background: black;\n", | |
| " --sklearn-color-background: white;\n", | |
| " --sklearn-color-border-box: black;\n", | |
| " --sklearn-color-icon: #696969;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1.dark {\n", | |
| " --sklearn-color-text-on-default-background: white;\n", | |
| " --sklearn-color-background: #111;\n", | |
| " --sklearn-color-border-box: white;\n", | |
| " --sklearn-color-icon: #878787;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 {\n", | |
| " color: var(--sklearn-color-text);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 pre {\n", | |
| " padding: 0;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 input.sk-hidden--visually {\n", | |
| " border: 0;\n", | |
| " clip: rect(1px 1px 1px 1px);\n", | |
| " clip: rect(1px, 1px, 1px, 1px);\n", | |
| " height: 1px;\n", | |
| " margin: -1px;\n", | |
| " overflow: hidden;\n", | |
| " padding: 0;\n", | |
| " position: absolute;\n", | |
| " width: 1px;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-dashed-wrapped {\n", | |
| " border: 1px dashed var(--sklearn-color-line);\n", | |
| " margin: 0 0.4em 0.5em 0.4em;\n", | |
| " box-sizing: border-box;\n", | |
| " padding-bottom: 0.4em;\n", | |
| " background-color: var(--sklearn-color-background);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-container {\n", | |
| " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", | |
| " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", | |
| " so we also need the `!important` here to be able to override the\n", | |
| " default hidden behavior on the sphinx rendered scikit-learn.org.\n", | |
| " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", | |
| " display: inline-block !important;\n", | |
| " position: relative;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-text-repr-fallback {\n", | |
| " display: none;\n", | |
| "}\n", | |
| "\n", | |
| "div.sk-parallel-item,\n", | |
| "div.sk-serial,\n", | |
| "div.sk-item {\n", | |
| " /* draw centered vertical line to link estimators */\n", | |
| " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", | |
| " background-size: 2px 100%;\n", | |
| " background-repeat: no-repeat;\n", | |
| " background-position: center center;\n", | |
| "}\n", | |
| "\n", | |
| "/* Parallel-specific style estimator block */\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-parallel-item::after {\n", | |
| " content: \"\";\n", | |
| " width: 100%;\n", | |
| " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", | |
| " flex-grow: 1;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-parallel {\n", | |
| " display: flex;\n", | |
| " align-items: stretch;\n", | |
| " justify-content: center;\n", | |
| " background-color: var(--sklearn-color-background);\n", | |
| " position: relative;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-parallel-item {\n", | |
| " display: flex;\n", | |
| " flex-direction: column;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n", | |
| " align-self: flex-end;\n", | |
| " width: 50%;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n", | |
| " align-self: flex-start;\n", | |
| " width: 50%;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n", | |
| " width: 0;\n", | |
| "}\n", | |
| "\n", | |
| "/* Serial-specific style estimator block */\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-serial {\n", | |
| " display: flex;\n", | |
| " flex-direction: column;\n", | |
| " align-items: center;\n", | |
| " background-color: var(--sklearn-color-background);\n", | |
| " padding-right: 1em;\n", | |
| " padding-left: 1em;\n", | |
| "}\n", | |
| "\n", | |
| "\n", | |
| "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", | |
| "clickable and can be expanded/collapsed.\n", | |
| "- Pipeline and ColumnTransformer use this feature and define the default style\n", | |
| "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", | |
| "*/\n", | |
| "\n", | |
| "/* Pipeline and ColumnTransformer style (default) */\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-toggleable {\n", | |
| " /* Default theme specific background. It is overwritten whether we have a\n", | |
| " specific estimator or a Pipeline/ColumnTransformer */\n", | |
| " background-color: var(--sklearn-color-background);\n", | |
| "}\n", | |
| "\n", | |
| "/* Toggleable label */\n", | |
| "#sk-container-id-1 label.sk-toggleable__label {\n", | |
| " cursor: pointer;\n", | |
| " display: flex;\n", | |
| " width: 100%;\n", | |
| " margin-bottom: 0;\n", | |
| " padding: 0.5em;\n", | |
| " box-sizing: border-box;\n", | |
| " text-align: center;\n", | |
| " align-items: center;\n", | |
| " justify-content: center;\n", | |
| " gap: 0.5em;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 label.sk-toggleable__label .caption {\n", | |
| " font-size: 0.6rem;\n", | |
| " font-weight: lighter;\n", | |
| " color: var(--sklearn-color-text-muted);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n", | |
| " /* Arrow on the left of the label */\n", | |
| " content: \"▸\";\n", | |
| " float: left;\n", | |
| " margin-right: 0.25em;\n", | |
| " color: var(--sklearn-color-icon);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n", | |
| " color: var(--sklearn-color-text);\n", | |
| "}\n", | |
| "\n", | |
| "/* Toggleable content - dropdown */\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-toggleable__content {\n", | |
| " display: none;\n", | |
| " text-align: left;\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-0);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-toggleable__content.fitted {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-0);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-toggleable__content pre {\n", | |
| " margin: 0.2em;\n", | |
| " border-radius: 0.25em;\n", | |
| " color: var(--sklearn-color-text);\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-0);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-0);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", | |
| " /* Expand drop-down */\n", | |
| " display: block;\n", | |
| " width: 100%;\n", | |
| " overflow: visible;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", | |
| " content: \"▾\";\n", | |
| "}\n", | |
| "\n", | |
| "/* Pipeline/ColumnTransformer-specific style */\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", | |
| " color: var(--sklearn-color-text);\n", | |
| " background-color: var(--sklearn-color-unfitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", | |
| " background-color: var(--sklearn-color-fitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "/* Estimator-specific style */\n", | |
| "\n", | |
| "/* Colorize estimator box */\n", | |
| "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n", | |
| "#sk-container-id-1 div.sk-label label {\n", | |
| " /* The background is the default theme color */\n", | |
| " color: var(--sklearn-color-text-on-default-background);\n", | |
| "}\n", | |
| "\n", | |
| "/* On hover, darken the color of the background */\n", | |
| "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n", | |
| " color: var(--sklearn-color-text);\n", | |
| " background-color: var(--sklearn-color-unfitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "/* Label box, darken color on hover, fitted */\n", | |
| "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", | |
| " color: var(--sklearn-color-text);\n", | |
| " background-color: var(--sklearn-color-fitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "/* Estimator label */\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-label label {\n", | |
| " font-family: monospace;\n", | |
| " font-weight: bold;\n", | |
| " line-height: 1.2em;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-label-container {\n", | |
| " text-align: center;\n", | |
| "}\n", | |
| "\n", | |
| "/* Estimator-specific */\n", | |
| "#sk-container-id-1 div.sk-estimator {\n", | |
| " font-family: monospace;\n", | |
| " border: 1px dotted var(--sklearn-color-border-box);\n", | |
| " border-radius: 0.25em;\n", | |
| " box-sizing: border-box;\n", | |
| " margin-bottom: 0.5em;\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-0);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-estimator.fitted {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-0);\n", | |
| "}\n", | |
| "\n", | |
| "/* on hover */\n", | |
| "#sk-container-id-1 div.sk-estimator:hover {\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 div.sk-estimator.fitted:hover {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-2);\n", | |
| "}\n", | |
| "\n", | |
| "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", | |
| "\n", | |
| "/* Common style for \"i\" and \"?\" */\n", | |
| "\n", | |
| ".sk-estimator-doc-link,\n", | |
| "a:link.sk-estimator-doc-link,\n", | |
| "a:visited.sk-estimator-doc-link {\n", | |
| " float: right;\n", | |
| " font-size: smaller;\n", | |
| " line-height: 1em;\n", | |
| " font-family: monospace;\n", | |
| " background-color: var(--sklearn-color-unfitted-level-0);\n", | |
| " border-radius: 1em;\n", | |
| " height: 1em;\n", | |
| " width: 1em;\n", | |
| " text-decoration: none !important;\n", | |
| " margin-left: 0.5em;\n", | |
| " text-align: center;\n", | |
| " /* unfitted */\n", | |
| " border: var(--sklearn-color-unfitted-level-3) 1pt solid;\n", | |
| " color: var(--sklearn-color-unfitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| ".sk-estimator-doc-link.fitted,\n", | |
| "a:link.sk-estimator-doc-link.fitted,\n", | |
| "a:visited.sk-estimator-doc-link.fitted {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-0);\n", | |
| " border: var(--sklearn-color-fitted-level-3) 1pt solid;\n", | |
| " color: var(--sklearn-color-fitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| "/* On hover */\n", | |
| "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", | |
| ".sk-estimator-doc-link:hover,\n", | |
| "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", | |
| ".sk-estimator-doc-link:hover {\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-3);\n", | |
| " border: var(--sklearn-color-fitted-level-0) 1pt solid;\n", | |
| " color: var(--sklearn-color-unfitted-level-0);\n", | |
| " text-decoration: none;\n", | |
| "}\n", | |
| "\n", | |
| "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", | |
| ".sk-estimator-doc-link.fitted:hover,\n", | |
| "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", | |
| ".sk-estimator-doc-link.fitted:hover {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-3);\n", | |
| " border: var(--sklearn-color-fitted-level-0) 1pt solid;\n", | |
| " color: var(--sklearn-color-fitted-level-0);\n", | |
| " text-decoration: none;\n", | |
| "}\n", | |
| "\n", | |
| "/* Span, style for the box shown on hovering the info icon */\n", | |
| ".sk-estimator-doc-link span {\n", | |
| " display: none;\n", | |
| " z-index: 9999;\n", | |
| " position: relative;\n", | |
| " font-weight: normal;\n", | |
| " right: .2ex;\n", | |
| " padding: .5ex;\n", | |
| " margin: .5ex;\n", | |
| " width: min-content;\n", | |
| " min-width: 20ex;\n", | |
| " max-width: 50ex;\n", | |
| " color: var(--sklearn-color-text);\n", | |
| " box-shadow: 2pt 2pt 4pt #999;\n", | |
| " /* unfitted */\n", | |
| " background: var(--sklearn-color-unfitted-level-0);\n", | |
| " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| ".sk-estimator-doc-link.fitted span {\n", | |
| " /* fitted */\n", | |
| " background: var(--sklearn-color-fitted-level-0);\n", | |
| " border: var(--sklearn-color-fitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| ".sk-estimator-doc-link:hover span {\n", | |
| " display: block;\n", | |
| "}\n", | |
| "\n", | |
| "/* \"?\"-specific style due to the `<a>` HTML tag */\n", | |
| "\n", | |
| "#sk-container-id-1 a.estimator_doc_link {\n", | |
| " float: right;\n", | |
| " font-size: 1rem;\n", | |
| " line-height: 1em;\n", | |
| " font-family: monospace;\n", | |
| " background-color: var(--sklearn-color-unfitted-level-0);\n", | |
| " border-radius: 1rem;\n", | |
| " height: 1rem;\n", | |
| " width: 1rem;\n", | |
| " text-decoration: none;\n", | |
| " /* unfitted */\n", | |
| " color: var(--sklearn-color-unfitted-level-1);\n", | |
| " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 a.estimator_doc_link.fitted {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-0);\n", | |
| " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", | |
| " color: var(--sklearn-color-fitted-level-1);\n", | |
| "}\n", | |
| "\n", | |
| "/* On hover */\n", | |
| "#sk-container-id-1 a.estimator_doc_link:hover {\n", | |
| " /* unfitted */\n", | |
| " background-color: var(--sklearn-color-unfitted-level-3);\n", | |
| " color: var(--sklearn-color-background);\n", | |
| " text-decoration: none;\n", | |
| "}\n", | |
| "\n", | |
| "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n", | |
| " /* fitted */\n", | |
| " background-color: var(--sklearn-color-fitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table {\n", | |
| " font-family: monospace;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table summary {\n", | |
| " padding: .5rem;\n", | |
| " cursor: pointer;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table summary::marker {\n", | |
| " font-size: 0.7rem;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table details[open] {\n", | |
| " padding-left: 0.1rem;\n", | |
| " padding-right: 0.1rem;\n", | |
| " padding-bottom: 0.3rem;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table .parameters-table {\n", | |
| " margin-left: auto !important;\n", | |
| " margin-right: auto !important;\n", | |
| " margin-top: 0;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table .parameters-table tr:nth-child(odd) {\n", | |
| " background-color: #fff;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table .parameters-table tr:nth-child(even) {\n", | |
| " background-color: #f6f6f6;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table .parameters-table tr:hover {\n", | |
| " background-color: #e0e0e0;\n", | |
| "}\n", | |
| "\n", | |
| ".estimator-table table td {\n", | |
| " border: 1px solid rgba(106, 105, 104, 0.232);\n", | |
| "}\n", | |
| "\n", | |
| "/*\n", | |
| " `table td`is set in notebook with right text-align.\n", | |
| " We need to overwrite it.\n", | |
| "*/\n", | |
| ".estimator-table table td.param {\n", | |
| " text-align: left;\n", | |
| " position: relative;\n", | |
| " padding: 0;\n", | |
| "}\n", | |
| "\n", | |
| ".user-set td {\n", | |
| " color:rgb(255, 94, 0);\n", | |
| " text-align: left !important;\n", | |
| "}\n", | |
| "\n", | |
| ".user-set td.value {\n", | |
| " color:rgb(255, 94, 0);\n", | |
| " background-color: transparent;\n", | |
| "}\n", | |
| "\n", | |
| ".default td {\n", | |
| " color: black;\n", | |
| " text-align: left !important;\n", | |
| "}\n", | |
| "\n", | |
| ".user-set td i,\n", | |
| ".default td i {\n", | |
| " color: black;\n", | |
| "}\n", | |
| "\n", | |
| "/*\n", | |
| " Styles for parameter documentation links\n", | |
| " We need styling for visited so jupyter doesn't overwrite it\n", | |
| "*/\n", | |
| "a.param-doc-link,\n", | |
| "a.param-doc-link:link,\n", | |
| "a.param-doc-link:visited {\n", | |
| " text-decoration: underline dashed;\n", | |
| " text-underline-offset: .3em;\n", | |
| " color: inherit;\n", | |
| " display: block;\n", | |
| " padding: .5em;\n", | |
| "}\n", | |
| "\n", | |
| "/* \"hack\" to make the entire area of the cell containing the link clickable */\n", | |
| "a.param-doc-link::before {\n", | |
| " position: absolute;\n", | |
| " content: \"\";\n", | |
| " inset: 0;\n", | |
| "}\n", | |
| "\n", | |
| ".param-doc-description {\n", | |
| " display: none;\n", | |
| " position: absolute;\n", | |
| " z-index: 9999;\n", | |
| " left: 0;\n", | |
| " padding: .5ex;\n", | |
| " margin-left: 1.5em;\n", | |
| " color: var(--sklearn-color-text);\n", | |
| " box-shadow: .3em .3em .4em #999;\n", | |
| " width: max-content;\n", | |
| " text-align: left;\n", | |
| " max-height: 10em;\n", | |
| " overflow-y: auto;\n", | |
| "\n", | |
| " /* unfitted */\n", | |
| " background: var(--sklearn-color-unfitted-level-0);\n", | |
| " border: thin solid var(--sklearn-color-unfitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| "/* Fitted state for parameter tooltips */\n", | |
| ".fitted .param-doc-description {\n", | |
| " /* fitted */\n", | |
| " background: var(--sklearn-color-fitted-level-0);\n", | |
| " border: thin solid var(--sklearn-color-fitted-level-3);\n", | |
| "}\n", | |
| "\n", | |
| ".param-doc-link:hover .param-doc-description {\n", | |
| " display: block;\n", | |
| "}\n", | |
| "\n", | |
| ".copy-paste-icon {\n", | |
| " background-image: url(data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA0NDggNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZD0iTTIwOCAwTDMzMi4xIDBjMTIuNyAwIDI0LjkgNS4xIDMzLjkgMTQuMWw2Ny45IDY3LjljOSA5IDE0LjEgMjEuMiAxNC4xIDMzLjlMNDQ4IDMzNmMwIDI2LjUtMjEuNSA0OC00OCA0OGwtMTkyIDBjLTI2LjUgMC00OC0yMS41LTQ4LTQ4bDAtMjg4YzAtMjYuNSAyMS41LTQ4IDQ4LTQ4ek00OCAxMjhsODAgMCAwIDY0LTY0IDAgMCAyNTYgMTkyIDAgMC0zMiA2NCAwIDAgNDhjMCAyNi41LTIxLjUgNDgtNDggNDhMNDggNTEyYy0yNi41IDAtNDgtMjEuNS00OC00OEwwIDE3NmMwLTI2LjUgMjEuNS00OCA0OC00OHoiLz48L3N2Zz4=);\n", | |
| " background-repeat: no-repeat;\n", | |
| " background-size: 14px 14px;\n", | |
| " background-position: 0;\n", | |
| " display: inline-block;\n", | |
| " width: 14px;\n", | |
| " height: 14px;\n", | |
| " cursor: pointer;\n", | |
| "}\n", | |
| "</style><body><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>MLPClassifier(alpha=0.0, batch_size=69000, hidden_layer_sizes=12,\n", | |
| " learning_rate_init=0.1, max_iter=100, momentum=0.0,\n", | |
| " random_state=42, solver='sgd', verbose=50)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>MLPClassifier</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html\">?<span>Documentation for MLPClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\" data-param-prefix=\"\">\n", | |
| " <div class=\"estimator-table\">\n", | |
| " <details>\n", | |
| " <summary>Parameters</summary>\n", | |
| " <table class=\"parameters-table\">\n", | |
| " <tbody>\n", | |
| " \n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('hidden_layer_sizes',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=hidden_layer_sizes,-array-like%20of%20shape%28n_layers%20-%202%2C%29%2C%20default%3D%28100%2C%29\">\n", | |
| " hidden_layer_sizes\n", | |
| " <span class=\"param-doc-description\">hidden_layer_sizes: array-like of shape(n_layers - 2,), default=(100,)<br><br>The ith element represents the number of neurons in the ith<br>hidden layer.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">12</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('activation',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=activation,-%7B%27identity%27%2C%20%27logistic%27%2C%20%27tanh%27%2C%20%27relu%27%7D%2C%20default%3D%27relu%27\">\n", | |
| " activation\n", | |
| " <span class=\"param-doc-description\">activation: {'identity', 'logistic', 'tanh', 'relu'}, default='relu'<br><br>Activation function for the hidden layer.<br><br>- 'identity', no-op activation, useful to implement linear bottleneck,<br> returns f(x) = x<br><br>- 'logistic', the logistic sigmoid function,<br> returns f(x) = 1 / (1 + exp(-x)).<br><br>- 'tanh', the hyperbolic tan function,<br> returns f(x) = tanh(x).<br><br>- 'relu', the rectified linear unit function,<br> returns f(x) = max(0, x)</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">'relu'</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('solver',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=solver,-%7B%27lbfgs%27%2C%20%27sgd%27%2C%20%27adam%27%7D%2C%20default%3D%27adam%27\">\n", | |
| " solver\n", | |
| " <span class=\"param-doc-description\">solver: {'lbfgs', 'sgd', 'adam'}, default='adam'<br><br>The solver for weight optimization.<br><br>- 'lbfgs' is an optimizer in the family of quasi-Newton methods.<br><br>- 'sgd' refers to stochastic gradient descent.<br><br>- 'adam' refers to a stochastic gradient-based optimizer proposed<br> by Kingma, Diederik, and Jimmy Ba<br><br>For a comparison between Adam optimizer and SGD, see<br>:ref:`sphx_glr_auto_examples_neural_networks_plot_mlp_training_curves.py`.<br><br>Note: The default solver 'adam' works pretty well on relatively<br>large datasets (with thousands of training samples or more) in terms of<br>both training time and validation score.<br>For small datasets, however, 'lbfgs' can converge faster and perform<br>better.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">'sgd'</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('alpha',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=alpha,-float%2C%20default%3D0.0001\">\n", | |
| " alpha\n", | |
| " <span class=\"param-doc-description\">alpha: float, default=0.0001<br><br>Strength of the L2 regularization term. The L2 regularization term<br>is divided by the sample size when added to the loss.<br><br>For an example usage and visualization of varying regularization, see<br>:ref:`sphx_glr_auto_examples_neural_networks_plot_mlp_alpha.py`.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.0</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('batch_size',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=batch_size,-int%2C%20default%3D%27auto%27\">\n", | |
| " batch_size\n", | |
| " <span class=\"param-doc-description\">batch_size: int, default='auto'<br><br>Size of minibatches for stochastic optimizers.<br>If the solver is 'lbfgs', the classifier will not use minibatch.<br>When set to \"auto\", `batch_size=min(200, n_samples)`.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">69000</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('learning_rate',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=learning_rate,-%7B%27constant%27%2C%20%27invscaling%27%2C%20%27adaptive%27%7D%2C%20default%3D%27constant%27\">\n", | |
| " learning_rate\n", | |
| " <span class=\"param-doc-description\">learning_rate: {'constant', 'invscaling', 'adaptive'}, default='constant'<br><br>Learning rate schedule for weight updates.<br><br>- 'constant' is a constant learning rate given by<br> 'learning_rate_init'.<br><br>- 'invscaling' gradually decreases the learning rate at each<br> time step 't' using an inverse scaling exponent of 'power_t'.<br> effective_learning_rate = learning_rate_init / pow(t, power_t)<br><br>- 'adaptive' keeps the learning rate constant to<br> 'learning_rate_init' as long as training loss keeps decreasing.<br> Each time two consecutive epochs fail to decrease training loss by at<br> least tol, or fail to increase validation score by at least tol if<br> 'early_stopping' is on, the current learning rate is divided by 5.<br><br>Only used when ``solver='sgd'``.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">'constant'</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('learning_rate_init',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=learning_rate_init,-float%2C%20default%3D0.001\">\n", | |
| " learning_rate_init\n", | |
| " <span class=\"param-doc-description\">learning_rate_init: float, default=0.001<br><br>The initial learning rate used. It controls the step-size<br>in updating the weights. Only used when solver='sgd' or 'adam'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.1</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('power_t',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=power_t,-float%2C%20default%3D0.5\">\n", | |
| " power_t\n", | |
| " <span class=\"param-doc-description\">power_t: float, default=0.5<br><br>The exponent for inverse scaling learning rate.<br>It is used in updating effective learning rate when the learning_rate<br>is set to 'invscaling'. Only used when solver='sgd'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.5</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('max_iter',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=max_iter,-int%2C%20default%3D200\">\n", | |
| " max_iter\n", | |
| " <span class=\"param-doc-description\">max_iter: int, default=200<br><br>Maximum number of iterations. The solver iterates until convergence<br>(determined by 'tol') or this number of iterations. For stochastic<br>solvers ('sgd', 'adam'), note that this determines the number of epochs<br>(how many times each data point will be used), not the number of<br>gradient steps.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">100</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('shuffle',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=shuffle,-bool%2C%20default%3DTrue\">\n", | |
| " shuffle\n", | |
| " <span class=\"param-doc-description\">shuffle: bool, default=True<br><br>Whether to shuffle samples in each iteration. Only used when<br>solver='sgd' or 'adam'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">True</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('random_state',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=random_state,-int%2C%20RandomState%20instance%2C%20default%3DNone\">\n", | |
| " random_state\n", | |
| " <span class=\"param-doc-description\">random_state: int, RandomState instance, default=None<br><br>Determines random number generation for weights and bias<br>initialization, train-test split if early stopping is used, and batch<br>sampling when solver='sgd' or 'adam'.<br>Pass an int for reproducible results across multiple function calls.<br>See :term:`Glossary <random_state>`.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">42</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('tol',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=tol,-float%2C%20default%3D1e-4\">\n", | |
| " tol\n", | |
| " <span class=\"param-doc-description\">tol: float, default=1e-4<br><br>Tolerance for the optimization. When the loss or score is not improving<br>by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,<br>unless ``learning_rate`` is set to 'adaptive', convergence is<br>considered to be reached and training stops.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.0001</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('verbose',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=verbose,-bool%2C%20default%3DFalse\">\n", | |
| " verbose\n", | |
| " <span class=\"param-doc-description\">verbose: bool, default=False<br><br>Whether to print progress messages to stdout.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">50</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('warm_start',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=warm_start,-bool%2C%20default%3DFalse\">\n", | |
| " warm_start\n", | |
| " <span class=\"param-doc-description\">warm_start: bool, default=False<br><br>When set to True, reuse the solution of the previous<br>call to fit as initialization, otherwise, just erase the<br>previous solution. See :term:`the Glossary <warm_start>`.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">False</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"user-set\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('momentum',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=momentum,-float%2C%20default%3D0.9\">\n", | |
| " momentum\n", | |
| " <span class=\"param-doc-description\">momentum: float, default=0.9<br><br>Momentum for gradient descent update. Should be between 0 and 1. Only<br>used when solver='sgd'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.0</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('nesterovs_momentum',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=nesterovs_momentum,-bool%2C%20default%3DTrue\">\n", | |
| " nesterovs_momentum\n", | |
| " <span class=\"param-doc-description\">nesterovs_momentum: bool, default=True<br><br>Whether to use Nesterov's momentum. Only used when solver='sgd' and<br>momentum > 0.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">True</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('early_stopping',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=early_stopping,-bool%2C%20default%3DFalse\">\n", | |
| " early_stopping\n", | |
| " <span class=\"param-doc-description\">early_stopping: bool, default=False<br><br>Whether to use early stopping to terminate training when validation<br>score is not improving. If set to True, it will automatically set<br>aside ``validation_fraction`` of training data as validation and<br>terminate training when validation score is not improving by at least<br>``tol`` for ``n_iter_no_change`` consecutive epochs. The split is<br>stratified, except in a multilabel setting.<br>If early stopping is False, then the training stops when the training<br>loss does not improve by more than ``tol`` for ``n_iter_no_change``<br>consecutive passes over the training set.<br>Only effective when solver='sgd' or 'adam'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">False</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('validation_fraction',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=validation_fraction,-float%2C%20default%3D0.1\">\n", | |
| " validation_fraction\n", | |
| " <span class=\"param-doc-description\">validation_fraction: float, default=0.1<br><br>The proportion of training data to set aside as validation set for<br>early stopping. Must be between 0 and 1.<br>Only used if early_stopping is True.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.1</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('beta_1',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=beta_1,-float%2C%20default%3D0.9\">\n", | |
| " beta_1\n", | |
| " <span class=\"param-doc-description\">beta_1: float, default=0.9<br><br>Exponential decay rate for estimates of first moment vector in adam,<br>should be in [0, 1). Only used when solver='adam'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.9</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('beta_2',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=beta_2,-float%2C%20default%3D0.999\">\n", | |
| " beta_2\n", | |
| " <span class=\"param-doc-description\">beta_2: float, default=0.999<br><br>Exponential decay rate for estimates of second moment vector in adam,<br>should be in [0, 1). Only used when solver='adam'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">0.999</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('epsilon',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=epsilon,-float%2C%20default%3D1e-8\">\n", | |
| " epsilon\n", | |
| " <span class=\"param-doc-description\">epsilon: float, default=1e-8<br><br>Value for numerical stability in adam. Only used when solver='adam'.</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">1e-08</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('n_iter_no_change',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=n_iter_no_change,-int%2C%20default%3D10\">\n", | |
| " n_iter_no_change\n", | |
| " <span class=\"param-doc-description\">n_iter_no_change: int, default=10<br><br>Maximum number of epochs to not meet ``tol`` improvement.<br>Only effective when solver='sgd' or 'adam'.<br><br>.. versionadded:: 0.20</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">10</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| "\n", | |
| " <tr class=\"default\">\n", | |
| " <td><i class=\"copy-paste-icon\"\n", | |
| " onclick=\"copyToClipboard('max_fun',\n", | |
| " this.parentElement.nextElementSibling)\"\n", | |
| " ></i></td>\n", | |
| " <td class=\"param\">\n", | |
| " <a class=\"param-doc-link\"\n", | |
| " rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=max_fun,-int%2C%20default%3D15000\">\n", | |
| " max_fun\n", | |
| " <span class=\"param-doc-description\">max_fun: int, default=15000<br><br>Only used when solver='lbfgs'. Maximum number of loss function calls.<br>The solver iterates until convergence (determined by 'tol'), number<br>of iterations reaches max_iter, or this number of loss function calls.<br>Note that number of loss function calls will be greater than or equal<br>to the number of iterations for the `MLPClassifier`.<br><br>.. versionadded:: 0.22</span>\n", | |
| " </a>\n", | |
| " </td>\n", | |
| " <td class=\"value\">15000</td>\n", | |
| " </tr>\n", | |
| " \n", | |
| " </tbody>\n", | |
| " </table>\n", | |
| " </details>\n", | |
| " </div>\n", | |
| " </div></div></div></div></div><script>function copyToClipboard(text, element) {\n", | |
| " // Get the parameter prefix from the closest toggleable content\n", | |
| " const toggleableContent = element.closest('.sk-toggleable__content');\n", | |
| " const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';\n", | |
| " const fullParamName = paramPrefix ? `${paramPrefix}${text}` : text;\n", | |
| "\n", | |
| " const originalStyle = element.style;\n", | |
| " const computedStyle = window.getComputedStyle(element);\n", | |
| " const originalWidth = computedStyle.width;\n", | |
| " const originalHTML = element.innerHTML.replace('Copied!', '');\n", | |
| "\n", | |
| " navigator.clipboard.writeText(fullParamName)\n", | |
| " .then(() => {\n", | |
| " element.style.width = originalWidth;\n", | |
| " element.style.color = 'green';\n", | |
| " element.innerHTML = \"Copied!\";\n", | |
| "\n", | |
| " setTimeout(() => {\n", | |
| " element.innerHTML = originalHTML;\n", | |
| " element.style = originalStyle;\n", | |
| " }, 2000);\n", | |
| " })\n", | |
| " .catch(err => {\n", | |
| " console.error('Failed to copy:', err);\n", | |
| " element.style.color = 'red';\n", | |
| " element.innerHTML = \"Failed!\";\n", | |
| " setTimeout(() => {\n", | |
| " element.innerHTML = originalHTML;\n", | |
| " element.style = originalStyle;\n", | |
| " }, 2000);\n", | |
| " });\n", | |
| " return false;\n", | |
| "}\n", | |
| "\n", | |
| "document.querySelectorAll('.copy-paste-icon').forEach(function(element) {\n", | |
| " const toggleableContent = element.closest('.sk-toggleable__content');\n", | |
| " const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';\n", | |
| " const paramName = element.parentElement.nextElementSibling\n", | |
| " .textContent.trim().split(' ')[0];\n", | |
| " const fullParamName = paramPrefix ? `${paramPrefix}${paramName}` : paramName;\n", | |
| "\n", | |
| " element.setAttribute('title', fullParamName);\n", | |
| "});\n", | |
| "\n", | |
| "\n", | |
| "/**\n", | |
| " * Adapted from Skrub\n", | |
| " * https://github.com/skrub-data/skrub/blob/403466d1d5d4dc76a7ef569b3f8228db59a31dc3/skrub/_reporting/_data/templates/report.js#L789\n", | |
| " * @returns \"light\" or \"dark\"\n", | |
| " */\n", | |
| "function detectTheme(element) {\n", | |
| " const body = document.querySelector('body');\n", | |
| "\n", | |
| " // Check VSCode theme\n", | |
| " const themeKindAttr = body.getAttribute('data-vscode-theme-kind');\n", | |
| " const themeNameAttr = body.getAttribute('data-vscode-theme-name');\n", | |
| "\n", | |
| " if (themeKindAttr && themeNameAttr) {\n", | |
| " const themeKind = themeKindAttr.toLowerCase();\n", | |
| " const themeName = themeNameAttr.toLowerCase();\n", | |
| "\n", | |
| " if (themeKind.includes(\"dark\") || themeName.includes(\"dark\")) {\n", | |
| " return \"dark\";\n", | |
| " }\n", | |
| " if (themeKind.includes(\"light\") || themeName.includes(\"light\")) {\n", | |
| " return \"light\";\n", | |
| " }\n", | |
| " }\n", | |
| "\n", | |
| " // Check Jupyter theme\n", | |
| " if (body.getAttribute('data-jp-theme-light') === 'false') {\n", | |
| " return 'dark';\n", | |
| " } else if (body.getAttribute('data-jp-theme-light') === 'true') {\n", | |
| " return 'light';\n", | |
| " }\n", | |
| "\n", | |
| " // Guess based on a parent element's color\n", | |
| " const color = window.getComputedStyle(element.parentNode, null).getPropertyValue('color');\n", | |
| " const match = color.match(/^rgb\\s*\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)\\s*$/i);\n", | |
| " if (match) {\n", | |
| " const [r, g, b] = [\n", | |
| " parseFloat(match[1]),\n", | |
| " parseFloat(match[2]),\n", | |
| " parseFloat(match[3])\n", | |
| " ];\n", | |
| "\n", | |
| " // https://en.wikipedia.org/wiki/HSL_and_HSV#Lightness\n", | |
| " const luma = 0.299 * r + 0.587 * g + 0.114 * b;\n", | |
| "\n", | |
| " if (luma > 180) {\n", | |
| " // If the text is very bright we have a dark theme\n", | |
| " return 'dark';\n", | |
| " }\n", | |
| " if (luma < 75) {\n", | |
| " // If the text is very dark we have a light theme\n", | |
| " return 'light';\n", | |
| " }\n", | |
| " // Otherwise fall back to the next heuristic.\n", | |
| " }\n", | |
| "\n", | |
| " // Fallback to system preference\n", | |
| " return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';\n", | |
| "}\n", | |
| "\n", | |
| "\n", | |
| "function forceTheme(elementId) {\n", | |
| " const estimatorElement = document.querySelector(`#${elementId}`);\n", | |
| " if (estimatorElement === null) {\n", | |
| " console.error(`Element with id ${elementId} not found.`);\n", | |
| " } else {\n", | |
| " const theme = detectTheme(estimatorElement);\n", | |
| " estimatorElement.classList.add(theme);\n", | |
| " }\n", | |
| "}\n", | |
| "\n", | |
| "forceTheme('sk-container-id-1');</script></body>" | |
| ], | |
| "text/plain": [ | |
| "MLPClassifier(alpha=0.0, batch_size=69000, hidden_layer_sizes=12,\n", | |
| " learning_rate_init=0.1, max_iter=100, momentum=0.0,\n", | |
| " random_state=42, solver='sgd', verbose=50)" | |
| ] | |
| }, | |
| "execution_count": 53, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Import mlp classifier\n", | |
| "from sklearn.neural_network import MLPClassifier\n", | |
| "\n", | |
| "model_sklearn = MLPClassifier(\n", | |
| " learning_rate=\"constant\",\n", | |
| " learning_rate_init=0.1,\n", | |
| " hidden_layer_sizes=(12), \n", | |
| " activation=\"relu\", \n", | |
| " solver=\"sgd\", \n", | |
| " max_iter=100, \n", | |
| " random_state=42,\n", | |
| " verbose=50,\n", | |
| " batch_size=X_train_scaled.shape[1],\n", | |
| " alpha=0.0,\n", | |
| " momentum=0.0,\n", | |
| ")\n", | |
| "\n", | |
| "model_sklearn.fit(X_train_scaled.T, y_train)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 54, | |
| "id": "95b71894", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.853" | |
| ] | |
| }, | |
| "execution_count": 54, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model_sklearn.score(X_test_scaled.T, y_test)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 55, | |
| "id": "6f553d08", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([8, 4, 8, 7, 7], dtype=int16)" | |
| ] | |
| }, | |
| "execution_count": 55, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model_sklearn.predict(X_test_scaled.T[:5])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 56, | |
| "id": "0433b554", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'softmax'" | |
| ] | |
| }, | |
| "execution_count": 56, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model_sklearn.out_activation_" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "1df2899b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "3" | |
| ] | |
| }, | |
| "execution_count": 57, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model_sklearn.n_layers_" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 58, | |
| "id": "6394dfe4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Parâmetros numpy\n", | |
| "(12, 784)\n", | |
| "(10, 12)\n", | |
| "(12, 1)\n", | |
| "(10, 1)\n", | |
| "Parâmetros sklearn\n", | |
| "(784, 12)\n", | |
| "(12, 10)\n", | |
| "(12,)\n", | |
| "(10,)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Verificando parâmetros treináveis\n", | |
| "print(\"Parâmetros numpy\")\n", | |
| "for w in [W1, W2, b1, b2]:\n", | |
| " print(w.shape)\n", | |
| "\n", | |
| "print(\"Parâmetros sklearn\")\n", | |
| "for w in model_sklearn.coefs_ + model_sklearn.intercepts_:\n", | |
| " print(w.shape)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "cat", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.14" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment