Created May 10, 2022 10:36
-
-
Save chetanambi/980dff2636c54086985c3b13706a5b2c to your computer and use it in GitHub Desktop.
OneHotEncoder_vs_get_dummies
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "e905e1ce", | |
| "metadata": {}, | |
| "source": [ | |
| "# OneHotEncoder Vs get_dummies" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "19b73275", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import seaborn as sns\n", | |
| "from sklearn.pipeline import Pipeline\n", | |
| "from sklearn.compose import ColumnTransformer\n", | |
| "from sklearn.preprocessing import OneHotEncoder\n", | |
| "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n", | |
| "from sklearn.linear_model import LinearRegression\n", | |
| "from sklearn.model_selection import train_test_split" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "5e919c45", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "df = sns.load_dataset('tips')\n", | |
| "df = df[['total_bill', 'tip', 'day', 'size']]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "8d2caae5", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>tip</th>\n", | |
| " <th>day</th>\n", | |
| " <th>size</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>16.99</td>\n", | |
| " <td>1.01</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>10.34</td>\n", | |
| " <td>1.66</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>3</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>21.01</td>\n", | |
| " <td>3.50</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>3</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>23.68</td>\n", | |
| " <td>3.31</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>24.59</td>\n", | |
| " <td>3.61</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>4</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill tip day size\n", | |
| "0 16.99 1.01 Sun 2\n", | |
| "1 10.34 1.66 Sun 3\n", | |
| "2 21.01 3.50 Sun 3\n", | |
| "3 23.68 3.31 Sun 2\n", | |
| "4 24.59 3.61 Sun 4" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "df.head(5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e2244d85", | |
| "metadata": {}, | |
| "source": [ | |
| "## Scikit-learn OneHotEncoder" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "0cecdd2c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X = df.drop('tip', axis=1)\n", | |
| "y = df['tip']\n", | |
| "\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "c19f3095", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "ohe = OneHotEncoder(handle_unknown='ignore', sparse=False, dtype='int')\n", | |
| "ohe.fit(X_train[['day']])\n", | |
| "\n", | |
| "def get_ohe(df):\n", | |
| " temp_df = pd.DataFrame(data=ohe.transform(df[['day']]), columns=ohe.get_feature_names_out())\n", | |
| " df.drop(columns=['day'], axis=1, inplace=True)\n", | |
| " df = pd.concat([df.reset_index(drop=True), temp_df], axis=1)\n", | |
| " return df" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "ebea7235", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X_train = get_ohe(X_train)\n", | |
| "X_test = get_ohe(X_test)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "a4c6a11f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>size</th>\n", | |
| " <th>day_Fri</th>\n", | |
| " <th>day_Sat</th>\n", | |
| " <th>day_Sun</th>\n", | |
| " <th>day_Thur</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>26.88</td>\n", | |
| " <td>4</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>32.68</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>17.89</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>20.49</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>48.17</td>\n", | |
| " <td>6</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill size day_Fri day_Sat day_Sun day_Thur\n", | |
| "0 26.88 4 0 0 1 0\n", | |
| "1 32.68 2 0 0 0 1\n", | |
| "2 17.89 2 0 0 1 0\n", | |
| "3 20.49 2 0 1 0 0\n", | |
| "4 48.17 6 0 0 1 0" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_train.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "9c32e5fd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>size</th>\n", | |
| " <th>day_Fri</th>\n", | |
| " <th>day_Sat</th>\n", | |
| " <th>day_Sun</th>\n", | |
| " <th>day_Thur</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>45</td>\n", | |
| " <td>4</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill size day_Fri day_Sat day_Sun day_Thur\n", | |
| "0 25 2 0 0 1 0\n", | |
| "1 45 4 0 0 0 0" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
| "get_ohe(X_test_new)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "148f3359", | |
| "metadata": {}, | |
| "source": [ | |
| "### One hot encoding in pipeline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "49542fbd", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X = df.drop('tip', axis=1)\n", | |
| "y = df['tip']\n", | |
| "\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "1e620741", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "numeric_preprocessor = Pipeline(steps=[\n", | |
| " (\"scaler\", MinMaxScaler()) \n", | |
| "])\n", | |
| "\n", | |
| "categorical_preprocessor = Pipeline(steps=[ \n", | |
| " (\"onehot\", OneHotEncoder(handle_unknown=\"ignore\")) \n", | |
| "])\n", | |
| "\n", | |
| "preprocessor = ColumnTransformer([\n", | |
| " (\"categorical\", categorical_preprocessor, [\"day\"]),\n", | |
| " (\"numerical\", numeric_preprocessor, [\"total_bill\", \"size\"])\n", | |
| "])\n", | |
| "\n", | |
| "pipe = Pipeline(steps=[\n", | |
| " (\"preprocessor\", preprocessor), \n", | |
| " (\"classifier\", LinearRegression())\n", | |
| "])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "3f9e210c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.5706168878130049" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pipe.fit(X_train, y_train)\n", | |
| "pipe.score(X_test, y_test)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "dea80c6d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>day</th>\n", | |
| " <th>size</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>45</td>\n", | |
| " <td>Mon</td>\n", | |
| " <td>4</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill day size\n", | |
| "0 25 Sun 2\n", | |
| "1 45 Mon 4" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
| "X_test_new" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "7770b3a2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([3.27325822, 5.436413 ])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pipe.predict(X_test_new)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ddb6ac19", | |
| "metadata": {}, | |
| "source": [ | |
| "## Pandas get_dummies " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "26e6c5eb", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X = df.drop('tip', axis=1)\n", | |
| "y = df['tip']\n", | |
| "\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "248ff42a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X_train = pd.get_dummies(X_train, columns=['day'])\n", | |
| "X_test = pd.get_dummies(X_test, columns=['day'])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "0fd105d0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>size</th>\n", | |
| " <th>day_Thur</th>\n", | |
| " <th>day_Fri</th>\n", | |
| " <th>day_Sat</th>\n", | |
| " <th>day_Sun</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>7</th>\n", | |
| " <td>26.88</td>\n", | |
| " <td>4</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>83</th>\n", | |
| " <td>32.68</td>\n", | |
| " <td>2</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>176</th>\n", | |
| " <td>17.89</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>106</th>\n", | |
| " <td>20.49</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>156</th>\n", | |
| " <td>48.17</td>\n", | |
| " <td>6</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill size day_Thur day_Fri day_Sat day_Sun\n", | |
| "7 26.88 4 0 0 0 1\n", | |
| "83 32.68 2 1 0 0 0\n", | |
| "176 17.89 2 0 0 0 1\n", | |
| "106 20.49 2 0 0 1 0\n", | |
| "156 48.17 6 0 0 0 1" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_train.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "0c048ce6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "cols = X_test.columns.tolist()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "015e66a3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>size</th>\n", | |
| " <th>day_Mon</th>\n", | |
| " <th>day_Sun</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>45</td>\n", | |
| " <td>4</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill size day_Mon day_Sun\n", | |
| "0 25 2 0 1\n", | |
| "1 45 4 1 0" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
| "pd.get_dummies(X_test_new, columns=['day']).head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "75cd1330", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>day</th>\n", | |
| " <th>size</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>45</td>\n", | |
| " <td>Mon</td>\n", | |
| " <td>4</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill day size\n", | |
| "0 25 Sun 2\n", | |
| "1 45 Mon 4" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_test_new" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "6abea7c5", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>size</th>\n", | |
| " <th>day_Thur</th>\n", | |
| " <th>day_Fri</th>\n", | |
| " <th>day_Sat</th>\n", | |
| " <th>day_Sun</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>2</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>45</td>\n", | |
| " <td>4</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill size day_Thur day_Fri day_Sat day_Sun\n", | |
| "0 25 2 0.0 0.0 0.0 1\n", | |
| "1 45 4 0.0 0.0 0.0 0" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_test_new = pd.get_dummies(X_test_new, columns=['day'])\n", | |
| "X_test_new.reindex(columns=cols).fillna(0) " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "8c3be6c6", | |
| "metadata": {}, | |
| "source": [ | |
| "### get_dummies in Pipeline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "72a6c0fc", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X = df.drop('tip', axis=1)\n", | |
| "y = df['tip']\n", | |
| "\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "391df006", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "cols = [\"total_bill\", \"size\", \"day_Fri\", \"day_Sat\", \"day_Sun\", \"day_Thur\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "fafc567b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.base import BaseEstimator, TransformerMixin\n", | |
| "\n", | |
| "class PreprocessorTransformer(BaseEstimator, TransformerMixin):\n", | |
| " def __init__(self, cols):\n", | |
| " self.cols = cols\n", | |
| " \n", | |
| " def fit(self, X, y = None):\n", | |
| " return self\n", | |
| " \n", | |
| " def transform(self, X, y = None):\n", | |
| " X = pd.get_dummies(X, columns=['day'])\n", | |
| " X = X.reindex(columns=self.cols).fillna(0) \n", | |
| " return X[self.cols]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "id": "ae0ec1a3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "preprocessor = Pipeline(steps=[ \n", | |
| " (\"preprocessor\", PreprocessorTransformer(cols)) \n", | |
| "])\n", | |
| "\n", | |
| "pipe = Pipeline(steps=[\n", | |
| " (\"preprocessor\", preprocessor), \n", | |
| " (\"classifier\", LinearRegression())\n", | |
| "])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "id": "4e363485", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "pipe.fit(X_train, y_train);" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "id": "612e202f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.5706168878130053" | |
| ] | |
| }, | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pipe.score(X_test, y_test)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "id": "05518bcc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>total_bill</th>\n", | |
| " <th>day</th>\n", | |
| " <th>size</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>25</td>\n", | |
| " <td>Sun</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>45</td>\n", | |
| " <td>Mon</td>\n", | |
| " <td>4</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " total_bill day size\n", | |
| "0 25 Sun 2\n", | |
| "1 45 Mon 4" | |
| ] | |
| }, | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_test_new = pd.DataFrame( {'total_bill': [25, 45], 'day': ['Sun', 'Mon'], 'size': [2, 4]} )\n", | |
| "X_test_new" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "id": "e475f927", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([3.27325822, 5.436413 ])" | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pipe.predict(X_test_new)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "a49f3eaf", | |
| "metadata": {}, | |
| "source": [ | |
| "## Example" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "id": "2930f647", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "data = {\"Airline\": [\"American Airlines\", \"Delta Air Lines\", \"United Airlines\"]}\n", | |
| "df = pd.DataFrame(data=data)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "id": "f0fd2cd9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>Airline_American Airlines</th>\n", | |
| " <th>Airline_Delta Air Lines</th>\n", | |
| " <th>Airline_United Airlines</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " Airline_American Airlines Airline_Delta Air Lines Airline_United Airlines\n", | |
| "0 1 0 0\n", | |
| "1 0 1 0\n", | |
| "2 0 0 1" | |
| ] | |
| }, | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pd.get_dummies(df)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "id": "1f34ab08", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead tr th {\n", | |
| " text-align: left;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr>\n", | |
| " <th></th>\n", | |
| " <th>American Airlines</th>\n", | |
| " <th>Delta Air Lines</th>\n", | |
| " <th>United Airlines</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " American Airlines Delta Air Lines United Airlines\n", | |
| "0 1 0 0\n", | |
| "1 0 1 0\n", | |
| "2 0 0 1" | |
| ] | |
| }, | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ohe = OneHotEncoder(sparse=False)\n", | |
| "pd.DataFrame(ohe.fit_transform(df), columns=ohe.categories_, dtype='int8')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment