Created
November 14, 2020 00:05
-
-
Save jshirius/90ac6c9b4881bf4c9b96ba4d0007496a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "PCA(主成分分析)によるデータ水増しテスト.ipynb", | |
| "provenance": [], | |
| "collapsed_sections": [] | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "HG4e2sPbf_F9" | |
| }, | |
| "source": [ | |
| "# 主成分分析(PCA)を使ってデータの水増し(行を増やす)ができるかの検証\n", | |
| "- 結論は、可能と判断\n", | |
| "- ポイントは以下の通り\n", | |
| " - PCAの前に標準化すること\n", | |
| " - 圧縮のパラメータ(n_components)は、「元のカラム数 - 1」が良い\n", | |
| "- 実際にkaggleで、少ないラベルのところに適用したらスコアの上昇を確認できた" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ilh3dyyfWQw5" | |
| }, | |
| "source": [ | |
| "from sklearn.datasets import load_iris\n", | |
| "import pandas as pd\n", | |
| "from sklearn.preprocessing import StandardScaler\n", | |
| "from sklearn.decomposition import PCA" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "pMXjOh1dffh7" | |
| }, | |
| "source": [ | |
| "# データの準備\n", | |
| "- 今回は定番のirisを使う\n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "G2fX6RBxXmQa" | |
| }, | |
| "source": [ | |
| "iris = load_iris()\n", | |
| "df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", | |
| "df['target'] = iris.target\n", | |
| "\n", | |
| "#標準化したいカラム取り出す\n", | |
| "#今回は4カラム分\n", | |
| "X = df.iloc[:, 0:4]" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "o950e99aZKFD", | |
| "outputId": "3968cb0c-a662-4c48-90c6-bb2b33486f53", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 419 | |
| } | |
| }, | |
| "source": [ | |
| "X" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>sepal length (cm)</th>\n", | |
| " <th>sepal width (cm)</th>\n", | |
| " <th>petal length (cm)</th>\n", | |
| " <th>petal width (cm)</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>5.1</td>\n", | |
| " <td>3.5</td>\n", | |
| " <td>1.4</td>\n", | |
| " <td>0.2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>4.9</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1.4</td>\n", | |
| " <td>0.2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>4.7</td>\n", | |
| " <td>3.2</td>\n", | |
| " <td>1.3</td>\n", | |
| " <td>0.2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>4.6</td>\n", | |
| " <td>3.1</td>\n", | |
| " <td>1.5</td>\n", | |
| " <td>0.2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>5.0</td>\n", | |
| " <td>3.6</td>\n", | |
| " <td>1.4</td>\n", | |
| " <td>0.2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>145</th>\n", | |
| " <td>6.7</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>5.2</td>\n", | |
| " <td>2.3</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>146</th>\n", | |
| " <td>6.3</td>\n", | |
| " <td>2.5</td>\n", | |
| " <td>5.0</td>\n", | |
| " <td>1.9</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>147</th>\n", | |
| " <td>6.5</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>5.2</td>\n", | |
| " <td>2.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>148</th>\n", | |
| " <td>6.2</td>\n", | |
| " <td>3.4</td>\n", | |
| " <td>5.4</td>\n", | |
| " <td>2.3</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>149</th>\n", | |
| " <td>5.9</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>5.1</td>\n", | |
| " <td>1.8</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>150 rows × 4 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", | |
| "0 5.1 3.5 1.4 0.2\n", | |
| "1 4.9 3.0 1.4 0.2\n", | |
| "2 4.7 3.2 1.3 0.2\n", | |
| "3 4.6 3.1 1.5 0.2\n", | |
| "4 5.0 3.6 1.4 0.2\n", | |
| ".. ... ... ... ...\n", | |
| "145 6.7 3.0 5.2 2.3\n", | |
| "146 6.3 2.5 5.0 1.9\n", | |
| "147 6.5 3.0 5.2 2.0\n", | |
| "148 6.2 3.4 5.4 2.3\n", | |
| "149 5.9 3.0 5.1 1.8\n", | |
| "\n", | |
| "[150 rows x 4 columns]" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "6Xdm-0Tqfp7z" | |
| }, | |
| "source": [ | |
| "# 標準化のテスト\n", | |
| "- データを標準化して、さらにもとの値に戻るかの確認をしておく" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "34UjOW3fXxOE", | |
| "outputId": "40415f4a-ddb9-4e8b-d3f8-5364ad7381b6", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| } | |
| }, | |
| "source": [ | |
| "sc = StandardScaler()\n", | |
| "result = sc.fit(X).transform(X) #標準化\n", | |
| "#sc.inverse_transform(result) #標準化を元に戻す\n", | |
| "result*sc.scale_ + sc.mean_ #標準化を元に戻す()\n", | |
| "\n", | |
| "#標準化前の値と、標準化から復元したあたいは一致したことがわかった" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "array([[5.1, 3.5, 1.4, 0.2],\n", | |
| " [4.9, 3. , 1.4, 0.2],\n", | |
| " [4.7, 3.2, 1.3, 0.2],\n", | |
| " [4.6, 3.1, 1.5, 0.2],\n", | |
| " [5. , 3.6, 1.4, 0.2],\n", | |
| " [5.4, 3.9, 1.7, 0.4],\n", | |
| " [4.6, 3.4, 1.4, 0.3],\n", | |
| " [5. , 3.4, 1.5, 0.2],\n", | |
| " [4.4, 2.9, 1.4, 0.2],\n", | |
| " [4.9, 3.1, 1.5, 0.1],\n", | |
| " [5.4, 3.7, 1.5, 0.2],\n", | |
| " [4.8, 3.4, 1.6, 0.2],\n", | |
| " [4.8, 3. , 1.4, 0.1],\n", | |
| " [4.3, 3. , 1.1, 0.1],\n", | |
| " [5.8, 4. , 1.2, 0.2],\n", | |
| " [5.7, 4.4, 1.5, 0.4],\n", | |
| " [5.4, 3.9, 1.3, 0.4],\n", | |
| " [5.1, 3.5, 1.4, 0.3],\n", | |
| " [5.7, 3.8, 1.7, 0.3],\n", | |
| " [5.1, 3.8, 1.5, 0.3],\n", | |
| " [5.4, 3.4, 1.7, 0.2],\n", | |
| " [5.1, 3.7, 1.5, 0.4],\n", | |
| " [4.6, 3.6, 1. , 0.2],\n", | |
| " [5.1, 3.3, 1.7, 0.5],\n", | |
| " [4.8, 3.4, 1.9, 0.2],\n", | |
| " [5. , 3. , 1.6, 0.2],\n", | |
| " [5. , 3.4, 1.6, 0.4],\n", | |
| " [5.2, 3.5, 1.5, 0.2],\n", | |
| " [5.2, 3.4, 1.4, 0.2],\n", | |
| " [4.7, 3.2, 1.6, 0.2],\n", | |
| " [4.8, 3.1, 1.6, 0.2],\n", | |
| " [5.4, 3.4, 1.5, 0.4],\n", | |
| " [5.2, 4.1, 1.5, 0.1],\n", | |
| " [5.5, 4.2, 1.4, 0.2],\n", | |
| " [4.9, 3.1, 1.5, 0.2],\n", | |
| " [5. , 3.2, 1.2, 0.2],\n", | |
| " [5.5, 3.5, 1.3, 0.2],\n", | |
| " [4.9, 3.6, 1.4, 0.1],\n", | |
| " [4.4, 3. , 1.3, 0.2],\n", | |
| " [5.1, 3.4, 1.5, 0.2],\n", | |
| " [5. , 3.5, 1.3, 0.3],\n", | |
| " [4.5, 2.3, 1.3, 0.3],\n", | |
| " [4.4, 3.2, 1.3, 0.2],\n", | |
| " [5. , 3.5, 1.6, 0.6],\n", | |
| " [5.1, 3.8, 1.9, 0.4],\n", | |
| " [4.8, 3. , 1.4, 0.3],\n", | |
| " [5.1, 3.8, 1.6, 0.2],\n", | |
| " [4.6, 3.2, 1.4, 0.2],\n", | |
| " [5.3, 3.7, 1.5, 0.2],\n", | |
| " [5. , 3.3, 1.4, 0.2],\n", | |
| " [7. , 3.2, 4.7, 1.4],\n", | |
| " [6.4, 3.2, 4.5, 1.5],\n", | |
| " [6.9, 3.1, 4.9, 1.5],\n", | |
| " [5.5, 2.3, 4. , 1.3],\n", | |
| " [6.5, 2.8, 4.6, 1.5],\n", | |
| " [5.7, 2.8, 4.5, 1.3],\n", | |
| " [6.3, 3.3, 4.7, 1.6],\n", | |
| " [4.9, 2.4, 3.3, 1. ],\n", | |
| " [6.6, 2.9, 4.6, 1.3],\n", | |
| " [5.2, 2.7, 3.9, 1.4],\n", | |
| " [5. , 2. , 3.5, 1. ],\n", | |
| " [5.9, 3. , 4.2, 1.5],\n", | |
| " [6. , 2.2, 4. , 1. ],\n", | |
| " [6.1, 2.9, 4.7, 1.4],\n", | |
| " [5.6, 2.9, 3.6, 1.3],\n", | |
| " [6.7, 3.1, 4.4, 1.4],\n", | |
| " [5.6, 3. , 4.5, 1.5],\n", | |
| " [5.8, 2.7, 4.1, 1. ],\n", | |
| " [6.2, 2.2, 4.5, 1.5],\n", | |
| " [5.6, 2.5, 3.9, 1.1],\n", | |
| " [5.9, 3.2, 4.8, 1.8],\n", | |
| " [6.1, 2.8, 4. , 1.3],\n", | |
| " [6.3, 2.5, 4.9, 1.5],\n", | |
| " [6.1, 2.8, 4.7, 1.2],\n", | |
| " [6.4, 2.9, 4.3, 1.3],\n", | |
| " [6.6, 3. , 4.4, 1.4],\n", | |
| " [6.8, 2.8, 4.8, 1.4],\n", | |
| " [6.7, 3. , 5. , 1.7],\n", | |
| " [6. , 2.9, 4.5, 1.5],\n", | |
| " [5.7, 2.6, 3.5, 1. ],\n", | |
| " [5.5, 2.4, 3.8, 1.1],\n", | |
| " [5.5, 2.4, 3.7, 1. ],\n", | |
| " [5.8, 2.7, 3.9, 1.2],\n", | |
| " [6. , 2.7, 5.1, 1.6],\n", | |
| " [5.4, 3. , 4.5, 1.5],\n", | |
| " [6. , 3.4, 4.5, 1.6],\n", | |
| " [6.7, 3.1, 4.7, 1.5],\n", | |
| " [6.3, 2.3, 4.4, 1.3],\n", | |
| " [5.6, 3. , 4.1, 1.3],\n", | |
| " [5.5, 2.5, 4. , 1.3],\n", | |
| " [5.5, 2.6, 4.4, 1.2],\n", | |
| " [6.1, 3. , 4.6, 1.4],\n", | |
| " [5.8, 2.6, 4. , 1.2],\n", | |
| " [5. , 2.3, 3.3, 1. ],\n", | |
| " [5.6, 2.7, 4.2, 1.3],\n", | |
| " [5.7, 3. , 4.2, 1.2],\n", | |
| " [5.7, 2.9, 4.2, 1.3],\n", | |
| " [6.2, 2.9, 4.3, 1.3],\n", | |
| " [5.1, 2.5, 3. , 1.1],\n", | |
| " [5.7, 2.8, 4.1, 1.3],\n", | |
| " [6.3, 3.3, 6. , 2.5],\n", | |
| " [5.8, 2.7, 5.1, 1.9],\n", | |
| " [7.1, 3. , 5.9, 2.1],\n", | |
| " [6.3, 2.9, 5.6, 1.8],\n", | |
| " [6.5, 3. , 5.8, 2.2],\n", | |
| " [7.6, 3. , 6.6, 2.1],\n", | |
| " [4.9, 2.5, 4.5, 1.7],\n", | |
| " [7.3, 2.9, 6.3, 1.8],\n", | |
| " [6.7, 2.5, 5.8, 1.8],\n", | |
| " [7.2, 3.6, 6.1, 2.5],\n", | |
| " [6.5, 3.2, 5.1, 2. ],\n", | |
| " [6.4, 2.7, 5.3, 1.9],\n", | |
| " [6.8, 3. , 5.5, 2.1],\n", | |
| " [5.7, 2.5, 5. , 2. ],\n", | |
| " [5.8, 2.8, 5.1, 2.4],\n", | |
| " [6.4, 3.2, 5.3, 2.3],\n", | |
| " [6.5, 3. , 5.5, 1.8],\n", | |
| " [7.7, 3.8, 6.7, 2.2],\n", | |
| " [7.7, 2.6, 6.9, 2.3],\n", | |
| " [6. , 2.2, 5. , 1.5],\n", | |
| " [6.9, 3.2, 5.7, 2.3],\n", | |
| " [5.6, 2.8, 4.9, 2. ],\n", | |
| " [7.7, 2.8, 6.7, 2. ],\n", | |
| " [6.3, 2.7, 4.9, 1.8],\n", | |
| " [6.7, 3.3, 5.7, 2.1],\n", | |
| " [7.2, 3.2, 6. , 1.8],\n", | |
| " [6.2, 2.8, 4.8, 1.8],\n", | |
| " [6.1, 3. , 4.9, 1.8],\n", | |
| " [6.4, 2.8, 5.6, 2.1],\n", | |
| " [7.2, 3. , 5.8, 1.6],\n", | |
| " [7.4, 2.8, 6.1, 1.9],\n", | |
| " [7.9, 3.8, 6.4, 2. ],\n", | |
| " [6.4, 2.8, 5.6, 2.2],\n", | |
| " [6.3, 2.8, 5.1, 1.5],\n", | |
| " [6.1, 2.6, 5.6, 1.4],\n", | |
| " [7.7, 3. , 6.1, 2.3],\n", | |
| " [6.3, 3.4, 5.6, 2.4],\n", | |
| " [6.4, 3.1, 5.5, 1.8],\n", | |
| " [6. , 3. , 4.8, 1.8],\n", | |
| " [6.9, 3.1, 5.4, 2.1],\n", | |
| " [6.7, 3.1, 5.6, 2.4],\n", | |
| " [6.9, 3.1, 5.1, 2.3],\n", | |
| " [5.8, 2.7, 5.1, 1.9],\n", | |
| " [6.8, 3.2, 5.9, 2.3],\n", | |
| " [6.7, 3.3, 5.7, 2.5],\n", | |
| " [6.7, 3. , 5.2, 2.3],\n", | |
| " [6.3, 2.5, 5. , 1.9],\n", | |
| " [6.5, 3. , 5.2, 2. ],\n", | |
| " [6.2, 3.4, 5.4, 2.3],\n", | |
| " [5.9, 3. , 5.1, 1.8]])" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 5 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "nRViW7rxXx1a", | |
| "outputId": "d2627b47-e689-400e-ad77-dd17c36c88fa", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| } | |
| }, | |
| "source": [ | |
| " sc.mean_, sc.scale_" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(array([5.84333333, 3.05733333, 3.758 , 1.19933333]),\n", | |
| " array([0.82530129, 0.43441097, 1.75940407, 0.75969263]))" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 6 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "zxRI6WKXhPth" | |
| }, | |
| "source": [ | |
| "# 標準化なしのPCAによるデータの水増しが出来るか確認\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "IiKrw7NwZ32H", | |
| "outputId": "b42e66e0-b436-447f-f9f8-6f0c581c7104", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| } | |
| }, | |
| "source": [ | |
| "#PCAによるデータ水増し実験\n", | |
| "n_comp = 3#(元のカラム数 - 1)\n", | |
| "\n", | |
| "#圧縮\n", | |
| "pca = PCA(n_components=n_comp, random_state=42)\n", | |
| "pca_res = pca.fit_transform(X)\n", | |
| "\n", | |
| "#圧縮をもとに戻す\n", | |
| "restore_list = pca.inverse_transform(pca_res)\n", | |
| "\n", | |
| "\n", | |
| "#1列目の差分を取る\n", | |
| "diff = 0\n", | |
| "for index, i in enumerate(X.values.tolist()):\n", | |
| " diff += abs(i[0] - restore_list[index][0])\n", | |
| " #print(i[0] )\n", | |
| "\n", | |
| "#1列目の元データと圧縮後の復元の差分を確認する\n", | |
| "print(diff)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "5.482752051480244\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "WgMQiyothwG0" | |
| }, | |
| "source": [ | |
| "# 標準化ありのPCAによるデータの水増しが出来るか確認\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "d9R96PqGbyHs", | |
| "outputId": "313ab784-c825-4fb4-8fe1-013d7daf2eed", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| } | |
| }, | |
| "source": [ | |
| "#PCAによるデータ水増し(標準化あり)\n", | |
| "n_comp = 3#(元のカラム数 - 1)\n", | |
| "\n", | |
| "sc = StandardScaler()\n", | |
| "X_sc = sc.fit(X).transform(X) #標準化\n", | |
| "\n", | |
| "#PCAによる圧縮\n", | |
| "pca = PCA(n_components=n_comp, random_state=42)\n", | |
| "pca_res = pca.fit_transform(X_sc)\n", | |
| "#PCAの復元化\n", | |
| "restore_list = pca.inverse_transform(pca_res) \n", | |
| "\n", | |
| "#標準化を元に戻す\n", | |
| "restore_list = restore_list*sc.scale_ + sc.mean_ \n", | |
| "\n", | |
| "#1列目の差分を取る\n", | |
| "diff = 0\n", | |
| "for index, i in enumerate(X.values.tolist()):\n", | |
| " diff += abs(i[0] - restore_list[index][0])\n", | |
| " #print(i[0] )\n", | |
| "print(diff)\n", | |
| "\n", | |
| "#実際にデータを行単位で水増しするときは、「restore_list」を、元データに追加することになる。\n" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "3.5309526623433074\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "mG2orkskiWCD" | |
| }, | |
| "source": [ | |
| "これまでの結果から、標準化してからPCA圧縮したほうが、\n", | |
| "より元のデータに近い形でデータの水増しが出来ることがわかった。\n", | |
| "(ただし、ここで比較しているのは1列目の絶対誤差の合計のみである点に注意。)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "i4K_0DApefVK", | |
| "outputId": "92c8e6ca-e1e6-4b34-b314-0cf4ed5f964a", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| } | |
| }, | |
| "source": [ | |
| "#これまでの結果を一般化するため関数にまとめる\n", | |
| "def data_augmentation(X, n_comp):\n", | |
| " #Xは元データ\n", | |
| " #n_compはn_componentsの圧縮パラメータ\n", | |
| "\n", | |
| " sc = StandardScaler()\n", | |
| " X_sc = sc.fit(X).transform(X) #標準化\n", | |
| "\n", | |
| " #PCAによる圧縮\n", | |
| " pca = PCA(n_components=n_comp, random_state=42)\n", | |
| " pca_res = pca.fit_transform(X_sc)\n", | |
| "\n", | |
| " #PCAの復元化\n", | |
| " restore_list = pca.inverse_transform(pca_res) \n", | |
| "\n", | |
| " #標準化を元に戻す\n", | |
| " restore_list = restore_list*sc.scale_ + sc.mean_ \n", | |
| "\n", | |
| " #1列目の差分も評価用に出力\n", | |
| " diff = 0\n", | |
| " for index, i in enumerate(X):\n", | |
| " diff += abs(i[0] - restore_list[index][0])\n", | |
| " #print(i[0] )\n", | |
| "\n", | |
| " return restore_list, diff\n", | |
| "\n", | |
| "aug_data, diff = data_augmentation(X.values.tolist(), 3)\n", | |
| "diff" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "3.5309526623433074" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 9 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "LkftBNwelK6F", | |
| "outputId": "cd6ca611-d352-401c-95c2-d842584709c1", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 419 | |
| } | |
| }, | |
| "source": [ | |
| "#pandasでデータを渡して、水増し対象の列のみ水増し、それ以外は、データをコピーして返す\n", | |
| "def pd_augmentation(df, aug_col, n_comp):\n", | |
| " #dfのコピー\n", | |
| " df_cp = df.copy()\n", | |
| " df_cp = df_cp.reset_index(drop=True)\n", | |
| "\n", | |
| " #水増し該当列のみ水増し対応\n", | |
| " aug_data, diff = data_augmentation(df_cp[aug_col].values.tolist(), n_comp)\n", | |
| "\n", | |
| " #再びdfに変換する\n", | |
| " aug_data_df = pd.DataFrame(aug_data, columns=aug_col)\n", | |
| "\n", | |
| " #水増し側のデータに上書きする\n", | |
| " for col in aug_col:\n", | |
| " df_cp[col] = aug_data_df[col]\n", | |
| "\n", | |
| " return df_cp\n", | |
| "\n", | |
| "\n", | |
| "aug_df = pd_augmentation(df, iris.feature_names, 3)\n", | |
| "aug_df" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>sepal length (cm)</th>\n", | |
| " <th>sepal width (cm)</th>\n", | |
| " <th>petal length (cm)</th>\n", | |
| " <th>petal width (cm)</th>\n", | |
| " <th>target</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>5.094788</td>\n", | |
| " <td>3.501297</td>\n", | |
| " <td>1.434079</td>\n", | |
| " <td>0.190387</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>4.877788</td>\n", | |
| " <td>3.005527</td>\n", | |
| " <td>1.545247</td>\n", | |
| " <td>0.159027</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>4.693881</td>\n", | |
| " <td>3.201523</td>\n", | |
| " <td>1.340014</td>\n", | |
| " <td>0.188712</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>4.614223</td>\n", | |
| " <td>3.096461</td>\n", | |
| " <td>1.406998</td>\n", | |
| " <td>0.226235</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>5.007746</td>\n", | |
| " <td>3.598073</td>\n", | |
| " <td>1.349346</td>\n", | |
| " <td>0.214289</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>145</th>\n", | |
| " <td>6.616061</td>\n", | |
| " <td>3.020885</td>\n", | |
| " <td>5.748881</td>\n", | |
| " <td>2.145164</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>146</th>\n", | |
| " <td>6.252518</td>\n", | |
| " <td>2.511814</td>\n", | |
| " <td>5.310487</td>\n", | |
| " <td>1.812414</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>147</th>\n", | |
| " <td>6.474302</td>\n", | |
| " <td>3.006394</td>\n", | |
| " <td>5.368040</td>\n", | |
| " <td>1.952597</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>148</th>\n", | |
| " <td>6.194366</td>\n", | |
| " <td>3.401402</td>\n", | |
| " <td>5.436843</td>\n", | |
| " <td>2.289607</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>149</th>\n", | |
| " <td>5.935166</td>\n", | |
| " <td>2.991250</td>\n", | |
| " <td>4.870048</td>\n", | |
| " <td>1.864868</td>\n", | |
| " <td>2</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>150 rows × 5 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " sepal length (cm) sepal width (cm) ... petal width (cm) target\n", | |
| "0 5.094788 3.501297 ... 0.190387 0\n", | |
| "1 4.877788 3.005527 ... 0.159027 0\n", | |
| "2 4.693881 3.201523 ... 0.188712 0\n", | |
| "3 4.614223 3.096461 ... 0.226235 0\n", | |
| "4 5.007746 3.598073 ... 0.214289 0\n", | |
| ".. ... ... ... ... ...\n", | |
| "145 6.616061 3.020885 ... 2.145164 2\n", | |
| "146 6.252518 2.511814 ... 1.812414 2\n", | |
| "147 6.474302 3.006394 ... 1.952597 2\n", | |
| "148 6.194366 3.401402 ... 2.289607 2\n", | |
| "149 5.935166 2.991250 ... 1.864868 2\n", | |
| "\n", | |
| "[150 rows x 5 columns]" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 20 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "U_DCipcACmHK" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment