Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save calvinmccarter/2cce9f6d06adbd48efebf32d788c84b9 to your computer and use it in GitHub Desktop.

Select an option

Save calvinmccarter/2cce9f6d06adbd48efebf32d788c84b9 to your computer and use it in GitHub Desktop.
reproducing Treeffuser m5 no tuning
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['/Users/cmccarter/miniconda3/envs/maskingtrees/lib/python39.zip', '/Users/cmccarter/miniconda3/envs/maskingtrees/lib/python3.9', '/Users/cmccarter/miniconda3/envs/maskingtrees/lib/python3.9/lib-dynload', '', '/Users/cmccarter/.local/lib/python3.9/site-packages', '/Users/cmccarter/miniconda3/envs/maskingtrees/lib/python3.9/site-packages', '/Users/cmccarter/sandbox/treeffuser/testbed/src', '/Users/cmccarter/sandbox/treeffuser/src', '../src', '../src']\n"
]
}
],
"source": [
"import os\n",
"import pickle as pkl\n",
"import sys\n",
"from functools import partial\n",
"from pathlib import Path\n",
"from typing import Callable, List\n",
"\n",
"# Make the local `testbed` package importable. Guarded so that re-running this\n",
"# cell does not append duplicate entries to sys.path.\n",
"SRC_DIR = \"../src\"\n",
"if SRC_DIR not in sys.path:\n",
"    sys.path.append(SRC_DIR)\n",
"print(sys.path)\n",
"\n",
"import lightgbm as lgb\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from jaxtyping import Array, Float\n",
"from tqdm import tqdm\n",
"\n",
"from testbed.models.treeffuser import Treeffuser\n",
"from testbed.models.base_model import BayesOptProbabilisticModel\n",
"from testbed.metrics.log_likelihood import LogLikelihoodFromSamplesMetric\n",
"from testbed.metrics.crps import CRPS\n",
"from testbed.metrics.accuracy import AccuracyMetric\n",
"\n",
"# NOTE: the autoreload extension is not loaded; add `%load_ext autoreload` and\n",
"# `%autoreload 2` here when live-editing the testbed sources.\n",
"\n",
"# Directory with the M5 CSV files; override via the M5_DATA_DIR environment variable.\n",
"path = os.environ.get(\n",
"    \"M5_DATA_DIR\",\n",
"    \"/Users/cmccarter/sandbox/treeffuser/testbed/src/testbed/data/m5\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Configuration flags for this notebook.\n",
"\n",
"PROCESS_FROM_SCRATCH = True  # build the train/test examples with proc_train_test\n",
"USE_SUBSET = True  # work on a random sample of series instead of the full dataset\n",
"CONTEXT_LENGTH = 20  # number of previous days of sales used as features\n",
"RUN_DEPRECATED = False  # presumably toggles deprecated cells -- TODO confirm its use"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"columns of sell_prices_df:\n",
"store_id\n",
"item_id\n",
"wm_yr_wk\n",
"sell_price\n",
"\n",
"columns of sales_train_validation_df:\n",
"id\n",
"item_id\n",
"dept_id\n",
"cat_id\n",
"store_id\n",
"state_id\n",
"\n",
"columns of calendar_df:\n",
"date\n",
"wm_yr_wk\n",
"weekday\n",
"wday\n",
"month\n",
"year\n",
"d\n",
"event_name_1\n",
"event_type_1\n",
"event_name_2\n",
"event_type_2\n",
"snap_CA\n",
"snap_TX\n",
"snap_WI\n",
"number of zeros in sales_train_validation_df: 39777094\n"
]
}
],
"source": [
"# READ IN DATA\n",
"\n",
"data_dir = Path(path)\n",
"sell_prices_df = pd.read_csv(data_dir / \"sell_prices.csv\")\n",
"sales_train_validation_df = pd.read_csv(data_dir / \"sales_train_validation.csv\")\n",
"calendar_df = pd.read_csv(data_dir / \"calendar.csv\")\n",
"\n",
"# List the columns of each table; the daily sales columns (d_1 ... d_1913) are skipped.\n",
"print(\"\\ncolumns of sell_prices_df:\")\n",
"for col in sell_prices_df.columns:\n",
"    print(col)\n",
"print(\"\\ncolumns of sales_train_validation_df:\")\n",
"for col in sales_train_validation_df.columns:\n",
"    if not col.startswith(\"d_\"):\n",
"        print(col)\n",
"print(\"\\ncolumns of calendar_df:\")\n",
"for col in calendar_df.columns:\n",
"    if not col.startswith(\"d_\"):\n",
"        print(col)\n",
"\n",
"# print number of zeros\n",
"print(\"number of zeros in sales_train_validation_df: \", (sales_train_validation_df == 0).sum().sum())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of zeros in sales_train_validation_df: 39777094 out of 58327370 entries\n",
"percentage of zeros in sales_train_validation_df: 68.20%\n"
]
}
],
"source": [
"# Share of zero-sales entries, counted over the daily sales columns (d_1, d_2, ...) only.\n",
"items_sold_cols = sales_train_validation_df.columns[sales_train_validation_df.columns.str.startswith(\"d_\")]\n",
"daily_sales = sales_train_validation_df[items_sold_cols]\n",
"num_zeros = (daily_sales == 0).sum().sum()\n",
"total_entries = daily_sales.size\n",
"\n",
"print(f\"number of zeros in sales_train_validation_df: {num_zeros} out of {total_entries} entries\")\n",
"print(f\"percentage of zeros in sales_train_validation_df: {num_zeros / total_entries * 100:.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Parse the dates once, then expose day / month / year as explicit columns\n",
"# (month and year already exist in calendar.csv and are overwritten here).\n",
"calendar_df[\"date\"] = pd.to_datetime(calendar_df[\"date\"])\n",
"date_parts = calendar_df[\"date\"].dt\n",
"for part in (\"day\", \"month\", \"year\"):\n",
"    calendar_df[part] = getattr(date_parts, part)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Brief snapshots of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>date</th>\n",
" <th>wm_yr_wk</th>\n",
" <th>weekday</th>\n",
" <th>wday</th>\n",
" <th>month</th>\n",
" <th>year</th>\n",
" <th>d</th>\n",
" <th>event_name_1</th>\n",
" <th>event_type_1</th>\n",
" <th>event_name_2</th>\n",
" <th>event_type_2</th>\n",
" <th>snap_CA</th>\n",
" <th>snap_TX</th>\n",
" <th>snap_WI</th>\n",
" <th>day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2011-01-29</td>\n",
" <td>11101</td>\n",
" <td>Saturday</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2011</td>\n",
" <td>d_1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2011-01-30</td>\n",
" <td>11101</td>\n",
" <td>Sunday</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>2011</td>\n",
" <td>d_2</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2011-01-31</td>\n",
" <td>11101</td>\n",
" <td>Monday</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>2011</td>\n",
" <td>d_3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>31</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2011-02-01</td>\n",
" <td>11101</td>\n",
" <td>Tuesday</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>2011</td>\n",
" <td>d_4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2011-02-02</td>\n",
" <td>11101</td>\n",
" <td>Wednesday</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>2011</td>\n",
" <td>d_5</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" date wm_yr_wk weekday wday month year d event_name_1 \\\n",
"0 2011-01-29 11101 Saturday 1 1 2011 d_1 NaN \n",
"1 2011-01-30 11101 Sunday 2 1 2011 d_2 NaN \n",
"2 2011-01-31 11101 Monday 3 1 2011 d_3 NaN \n",
"3 2011-02-01 11101 Tuesday 4 2 2011 d_4 NaN \n",
"4 2011-02-02 11101 Wednesday 5 2 2011 d_5 NaN \n",
"\n",
" event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI day \n",
"0 NaN NaN NaN 0 0 0 29 \n",
"1 NaN NaN NaN 0 0 0 30 \n",
"2 NaN NaN NaN 0 0 0 31 \n",
"3 NaN NaN NaN 1 1 0 1 \n",
"4 NaN NaN NaN 1 0 1 2 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First rows of the calendar table\n",
"calendar_df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>item_id</th>\n",
" <th>dept_id</th>\n",
" <th>cat_id</th>\n",
" <th>store_id</th>\n",
" <th>state_id</th>\n",
" <th>d_1</th>\n",
" <th>d_2</th>\n",
" <th>d_3</th>\n",
" <th>d_4</th>\n",
" <th>...</th>\n",
" <th>d_1904</th>\n",
" <th>d_1905</th>\n",
" <th>d_1906</th>\n",
" <th>d_1907</th>\n",
" <th>d_1908</th>\n",
" <th>d_1909</th>\n",
" <th>d_1910</th>\n",
" <th>d_1911</th>\n",
" <th>d_1912</th>\n",
" <th>d_1913</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>HOBBIES_1_001_CA_1_validation</td>\n",
" <td>HOBBIES_1_001</td>\n",
" <td>HOBBIES_1</td>\n",
" <td>HOBBIES</td>\n",
" <td>CA_1</td>\n",
" <td>CA</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>HOBBIES_1_002_CA_1_validation</td>\n",
" <td>HOBBIES_1_002</td>\n",
" <td>HOBBIES_1</td>\n",
" <td>HOBBIES</td>\n",
" <td>CA_1</td>\n",
" <td>CA</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>HOBBIES_1_003_CA_1_validation</td>\n",
" <td>HOBBIES_1_003</td>\n",
" <td>HOBBIES_1</td>\n",
" <td>HOBBIES</td>\n",
" <td>CA_1</td>\n",
" <td>CA</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>HOBBIES_1_004_CA_1_validation</td>\n",
" <td>HOBBIES_1_004</td>\n",
" <td>HOBBIES_1</td>\n",
" <td>HOBBIES</td>\n",
" <td>CA_1</td>\n",
" <td>CA</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>7</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>HOBBIES_1_005_CA_1_validation</td>\n",
" <td>HOBBIES_1_005</td>\n",
" <td>HOBBIES_1</td>\n",
" <td>HOBBIES</td>\n",
" <td>CA_1</td>\n",
" <td>CA</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 1919 columns</p>\n",
"</div>"
],
"text/plain": [
" id item_id dept_id cat_id store_id \\\n",
"0 HOBBIES_1_001_CA_1_validation HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 \n",
"1 HOBBIES_1_002_CA_1_validation HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 \n",
"2 HOBBIES_1_003_CA_1_validation HOBBIES_1_003 HOBBIES_1 HOBBIES CA_1 \n",
"3 HOBBIES_1_004_CA_1_validation HOBBIES_1_004 HOBBIES_1 HOBBIES CA_1 \n",
"4 HOBBIES_1_005_CA_1_validation HOBBIES_1_005 HOBBIES_1 HOBBIES CA_1 \n",
"\n",
" state_id d_1 d_2 d_3 d_4 ... d_1904 d_1905 d_1906 d_1907 d_1908 \\\n",
"0 CA 0 0 0 0 ... 1 3 0 1 1 \n",
"1 CA 0 0 0 0 ... 0 0 0 0 0 \n",
"2 CA 0 0 0 0 ... 2 1 2 1 1 \n",
"3 CA 0 0 0 0 ... 1 0 5 4 1 \n",
"4 CA 0 0 0 0 ... 2 1 1 0 1 \n",
"\n",
" d_1909 d_1910 d_1911 d_1912 d_1913 \n",
"0 1 3 0 1 1 \n",
"1 1 0 0 0 0 \n",
"2 1 0 1 1 1 \n",
"3 0 1 3 7 2 \n",
"4 1 2 2 2 4 \n",
"\n",
"[5 rows x 1919 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First rows of the sales table (columns d_1 ... d_1913 hold the daily sales)\n",
"sales_train_validation_df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>store_id</th>\n",
" <th>item_id</th>\n",
" <th>wm_yr_wk</th>\n",
" <th>sell_price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>CA_1</td>\n",
" <td>HOBBIES_1_001</td>\n",
" <td>11325</td>\n",
" <td>9.58</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>CA_1</td>\n",
" <td>HOBBIES_1_001</td>\n",
" <td>11326</td>\n",
" <td>9.58</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>CA_1</td>\n",
" <td>HOBBIES_1_001</td>\n",
" <td>11327</td>\n",
" <td>8.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>CA_1</td>\n",
" <td>HOBBIES_1_001</td>\n",
" <td>11328</td>\n",
" <td>8.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>CA_1</td>\n",
" <td>HOBBIES_1_001</td>\n",
" <td>11329</td>\n",
" <td>8.26</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" store_id item_id wm_yr_wk sell_price\n",
"0 CA_1 HOBBIES_1_001 11325 9.58\n",
"1 CA_1 HOBBIES_1_001 11326 9.58\n",
"2 CA_1 HOBBIES_1_001 11327 8.26\n",
"3 CA_1 HOBBIES_1_001 11328 8.26\n",
"4 CA_1 HOBBIES_1_001 11329 8.26"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First rows of the sell prices table\n",
"sell_prices_df.head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Process the data"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"TOTAL_ITEMS = 5000  # number of distinct `id` values (series) to sample\n",
"\n",
"if USE_SUBSET:\n",
"    # Select a random subset of series; seeding the global RNG keeps the sample reproducible.\n",
"    np.random.seed(0)\n",
"    unique_ids = sales_train_validation_df[\"id\"].unique()\n",
"    ids = np.random.choice(unique_ids, TOTAL_ITEMS, replace=False)\n",
"    sales_train_validation_df_sub = sales_train_validation_df[sales_train_validation_df[\"id\"].isin(ids)]\n",
"    item_ids = sales_train_validation_df_sub[\"item_id\"].unique()\n",
"    sell_prices_df_sub = sell_prices_df[sell_prices_df[\"item_id\"].isin(item_ids)]\n",
"    calendar_df_sub = calendar_df\n",
"else:\n",
"    # Use the full data so the *_sub names used by the processing cell are always defined.\n",
"    sales_train_validation_df_sub = sales_train_validation_df\n",
"    sell_prices_df_sub = sell_prices_df\n",
"    calendar_df_sub = calendar_df\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sales_train_validation_df.head()`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The strategy for processing the data is as follows: 1) we build pairs (X, y), where y is the next day's sales of a given product (one item/store series); 2) X is made up of the sales of the previous `CONTEXT_LENGTH` days, plus the target day's day of the week (`wday`), `month`, event names (`event_name_1`, `event_name_2`), `store_id`, `sell_price`, and `item_id`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def proc_train_test(\n",
"    sales_train_validation_df: pd.DataFrame,\n",
"    calendar_df: pd.DataFrame,\n",
"    sell_prices_df: pd.DataFrame,\n",
"    context_length: int,\n",
"    test_percentage: float,\n",
"    percentage_omittied: float = 0,\n",
"):\n",
"    \"\"\"\n",
"    Build sliding-window (X, y) examples from the M5 tables and split them into train/test.\n",
"\n",
"    For a window starting at day j, the features are the sales of days d_j ... d_{j+context_length-1}\n",
"    followed by [wday, month, store_id, event_name_1, event_name_2, sell_price, item_id] of the\n",
"    target day d_{j+context_length}; the target is the sales on that day.\n",
"\n",
"    The examples are returned in two groupings:\n",
"    - undifferentiated: flat lists of all examples (X_train, y_train, X_test, y_test)\n",
"    - differentiated: per-series lists (X_train_prod, y_train_prod, X_test_prod, y_test_prod),\n",
"      where X_train_prod[i] holds every training example built from row i of\n",
"      sales_train_validation_df, and likewise for y_train_prod and the test lists.\n",
"\n",
"    Expected columns:\n",
"    - sales_train_validation_df: d_1, d_2, ..., item_id, store_id\n",
"    - calendar_df: d, wday, month, event_name_1, event_name_2\n",
"    - sell_prices_df: item_id, store_id, sell_price\n",
"\n",
"    Args:\n",
"        context_length: number of previous days of sales used as features.\n",
"        test_percentage: fraction of the days, taken from the end, used as the test period.\n",
"        percentage_omittied: fraction (0-1) of each series' training windows that is randomly\n",
"            dropped. Test windows are never dropped.\n",
"\n",
"    Returns:\n",
"        (undifferentiated, differentiated), each a 4-tuple as described above.\n",
"    \"\"\"\n",
"    # Seed the global RNG so the subsampled training windows are reproducible.\n",
"    np.random.seed(0)\n",
"\n",
"    X_train, y_train = [], []\n",
"    X_test, y_test = [], []\n",
"    # Per-series groupings: index i refers to row i of sales_train_validation_df.\n",
"    X_train_prod, y_train_prod = [], []\n",
"    X_test_prod, y_test_prod = [], []\n",
"\n",
"    # Daily sales columns are d_1 ... d_<total_days>; the first train_days days are used for\n",
"    # training and the remaining ones for testing.\n",
"    total_days = max(int(col.split(\"_\")[1]) for col in sales_train_validation_df.columns if col.startswith(\"d_\"))\n",
"    train_days = int(total_days * (1 - test_percentage))\n",
"    print(\"train days\", train_days)\n",
"    print(\"test days\", total_days - train_days)\n",
"    print(\"total days\", total_days)\n",
"\n",
"    # Precompute lookups so the inner loop avoids repeated DataFrame indexing.\n",
"    calendar_df_dict = calendar_df.set_index(\"d\").to_dict(orient=\"index\")\n",
"    # NOTE(review): `.first()` keeps a single price per (item_id, store_id), so sell_price is\n",
"    # constant over time for each series rather than the price of the target week.\n",
"    sell_prices_dict = sell_prices_df.groupby([\"item_id\", \"store_id\"])[\"sell_price\"].first().to_dict()\n",
"\n",
"    with tqdm(total=len(sales_train_validation_df)) as pbar:\n",
"        for _, row in sales_train_validation_df.iterrows():\n",
"            item_id = row[\"item_id\"]\n",
"            store_id = row[\"store_id\"]\n",
"\n",
"            X_train_prod.append([])\n",
"            y_train_prod.append([])\n",
"            X_test_prod.append([])\n",
"            y_test_prod.append([])\n",
"\n",
"            pbar.update(1)\n",
"\n",
"            # Keep a random (1 - percentage_omittied) share of the training window starts,\n",
"            # plus every test window start.\n",
"            valid_size = int((train_days - context_length) * (1 - percentage_omittied))\n",
"            valid_js = np.random.choice(range(1, train_days - context_length), valid_size, replace=False)\n",
"            # NOTE(review): the upper bound is exclusive, so the last day d_<total_days> is\n",
"            # never used as a target -- confirm this is intended.\n",
"            valid_js = list(valid_js) + list(range(train_days, total_days - context_length))\n",
"\n",
"            for j in valid_js:\n",
"                # Sales values for the context_length days preceding the target day.\n",
"                x = [row[f\"d_{j + k}\"] for k in range(context_length)]\n",
"\n",
"                # Features of the target day itself.\n",
"                current_day = f\"d_{j + context_length}\"\n",
"                calendar_data = calendar_df_dict[current_day]\n",
"                x.extend([\n",
"                    calendar_data[\"wday\"],\n",
"                    calendar_data[\"month\"],\n",
"                    store_id,\n",
"                    calendar_data[\"event_name_1\"],\n",
"                    calendar_data[\"event_name_2\"],\n",
"                    sell_prices_dict[(item_id, store_id)],\n",
"                    item_id,\n",
"                ])\n",
"                target = row[current_day]\n",
"\n",
"                if j < train_days:\n",
"                    X_train.append(x)\n",
"                    y_train.append(target)\n",
"                    X_train_prod[-1].append(x)\n",
"                    y_train_prod[-1].append(target)\n",
"                else:\n",
"                    X_test.append(x)\n",
"                    y_test.append(target)\n",
"                    # Test windows belong in the per-series *test* lists (they were previously\n",
"                    # appended to X_train_prod / y_train_prod, leaving the test lists empty).\n",
"                    X_test_prod[-1].append(x)\n",
"                    y_test_prod[-1].append(target)\n",
"\n",
"    undifferentiated = (X_train, y_train, X_test, y_test)\n",
"    differentiated = (X_train_prod, y_train_prod, X_test_prod, y_test_prod)\n",
"    return undifferentiated, differentiated"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train days 1874\n",
"test days 39\n",
"total days 1913\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████| 5000/5000 [00:04<00:00, 1213.90it/s]\n"
]
}
],
"source": [
"if PROCESS_FROM_SCRATCH:\n",
"    undifferentiated, differentiated = proc_train_test(\n",
"        sales_train_validation_df_sub,\n",
"        calendar_df,\n",
"        sell_prices_df_sub,\n",
"        CONTEXT_LENGTH,\n",
"        test_percentage=0.02,  # the last 2% of the days form the test period\n",
"        percentage_omittied=0.99,  # keep only ~1% of each series' training windows\n",
"    )\n",
"    X_train, y_train, X_test, y_test = undifferentiated\n",
"    X_train_prod, y_train_prod, X_test_prod, y_test_prod = differentiated\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(90000, 90000, 95000, 95000)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: every example has a target in both splits.\n",
"assert len(X_train) == len(y_train)\n",
"assert len(X_test) == len(y_test)\n",
"len(X_train), len(y_train), len(X_test), len(y_test)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Column names, in the same order as the feature vectors built by proc_train_test:\n",
"# sales of the CONTEXT_LENGTH previous days (day_1 = oldest), then the target-day features.\n",
"COL_NAMES = [\n",
"    *(f\"day_{i}\" for i in range(1, CONTEXT_LENGTH + 1)),\n",
"    \"wday\", \"month\", \"store_id\", \"event_name_1\", \"event_name_2\", \"sell_price\", \"item_id\",\n",
"]\n",
"\n",
"# Features treated as categorical, and their positions within COL_NAMES.\n",
"CAT_COLS = [\"store_id\", \"event_name_1\", \"event_name_2\", \"item_id\", \"wday\", \"month\"]\n",
"CAT_COLS_IDX = list(map(COL_NAMES.index, CAT_COLS))\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the example lists as DataFrames; passing the feature names at construction\n",
"# also works when a split is empty.\n",
"X_train_df = pd.DataFrame(X_train, columns=COL_NAMES)\n",
"X_test_df = pd.DataFrame(X_test, columns=COL_NAMES)\n",
"y_test_df = pd.DataFrame(y_test)\n",
"y_train_df = pd.DataFrame(y_train)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>day_1</th>\n",
" <th>day_2</th>\n",
" <th>day_3</th>\n",
" <th>day_4</th>\n",
" <th>day_5</th>\n",
" <th>day_6</th>\n",
" <th>day_7</th>\n",
" <th>day_8</th>\n",
" <th>day_9</th>\n",
" <th>day_10</th>\n",
" <th>...</th>\n",
" <th>day_18</th>\n",
" <th>day_19</th>\n",
" <th>day_20</th>\n",
" <th>wday</th>\n",
" <th>month</th>\n",
" <th>store_id</th>\n",
" <th>event_name_1</th>\n",
" <th>event_name_2</th>\n",
" <th>sell_price</th>\n",
" <th>item_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>CA_1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>2.98</td>\n",
" <td>HOBBIES_1_005</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>10</td>\n",
" <td>CA_1</td>\n",
" <td>EidAlAdha</td>\n",
" <td>NaN</td>\n",
" <td>2.98</td>\n",
" <td>HOBBIES_1_005</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>CA_1</td>\n",
" <td>SuperBowl</td>\n",
" <td>NaN</td>\n",
" <td>2.98</td>\n",
" <td>HOBBIES_1_005</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>CA_1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>2.98</td>\n",
" <td>HOBBIES_1_005</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>CA_1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>2.98</td>\n",
" <td>HOBBIES_1_005</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89995</th>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>WI_3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.00</td>\n",
" <td>FOODS_3_827</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89996</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>WI_3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.00</td>\n",
" <td>FOODS_3_827</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89997</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>WI_3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.00</td>\n",
" <td>FOODS_3_827</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89998</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>8</td>\n",
" <td>WI_3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.00</td>\n",
" <td>FOODS_3_827</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89999</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>11</td>\n",
" <td>WI_3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>1.00</td>\n",
" <td>FOODS_3_827</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>90000 rows × 27 columns</p>\n",
"</div>"
],
"text/plain": [
" day_1 day_2 day_3 day_4 day_5 day_6 day_7 day_8 day_9 day_10 \\\n",
"0 2 0 0 4 1 0 6 0 0 0 \n",
"1 3 3 1 1 1 0 0 3 0 0 \n",
"2 0 0 1 2 1 1 1 1 1 1 \n",
"3 0 2 0 1 3 2 1 1 2 2 \n",
"4 0 0 0 0 0 0 0 0 0 0 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"89995 5 5 0 2 3 0 1 4 2 0 \n",
"89996 0 0 0 0 0 0 0 0 0 0 \n",
"89997 0 0 0 0 0 0 0 0 0 0 \n",
"89998 0 0 0 0 0 0 0 0 0 0 \n",
"89999 0 0 0 0 0 0 0 0 0 0 \n",
"\n",
" ... day_18 day_19 day_20 wday month store_id event_name_1 \\\n",
"0 ... 1 2 0 2 2 CA_1 NaN \n",
"1 ... 3 0 1 4 10 CA_1 EidAlAdha \n",
"2 ... 0 4 1 2 2 CA_1 SuperBowl \n",
"3 ... 2 0 2 4 4 CA_1 NaN \n",
"4 ... 0 0 0 4 3 CA_1 NaN \n",
"... ... ... ... ... ... ... ... ... \n",
"89995 ... 1 9 0 2 5 WI_3 NaN \n",
"89996 ... 0 0 0 1 3 WI_3 NaN \n",
"89997 ... 0 0 0 4 1 WI_3 NaN \n",
"89998 ... 0 0 0 2 8 WI_3 NaN \n",
"89999 ... 0 0 0 2 11 WI_3 NaN \n",
"\n",
" event_name_2 sell_price item_id \n",
"0 NaN 2.98 HOBBIES_1_005 \n",
"1 NaN 2.98 HOBBIES_1_005 \n",
"2 NaN 2.98 HOBBIES_1_005 \n",
"3 NaN 2.98 HOBBIES_1_005 \n",
"4 NaN 2.98 HOBBIES_1_005 \n",
"... ... ... ... \n",
"89995 NaN 1.00 FOODS_3_827 \n",
"89996 NaN 1.00 FOODS_3_827 \n",
"89997 NaN 1.00 FOODS_3_827 \n",
"89998 NaN 1.00 FOODS_3_827 \n",
"89999 NaN 1.00 FOODS_3_827 \n",
"\n",
"[90000 rows x 27 columns]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training feature table (pandas truncates the displayed rows)\n",
"X_train_df"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>day_1</th>\n",
" <th>day_2</th>\n",
" <th>day_3</th>\n",
" <th>day_4</th>\n",
" <th>day_5</th>\n",
" <th>day_6</th>\n",
" <th>day_7</th>\n",
" <th>day_8</th>\n",
" <th>day_9</th>\n",
" <th>day_10</th>\n",
" <th>...</th>\n",
" <th>day_18</th>\n",
" <th>day_19</th>\n",
" <th>day_20</th>\n",
" <th>wday</th>\n",
" <th>month</th>\n",
" <th>store_id</th>\n",
" <th>event_name_1</th>\n",
" <th>event_name_2</th>\n",
" <th>sell_price</th>\n",
" <th>item_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>4</td>\n",
" <td>2.98</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>2.98</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>26</td>\n",
" <td>4</td>\n",
" <td>2.98</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>4</td>\n",
" <td>2.98</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>4</td>\n",
" <td>2.98</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 27 columns</p>\n",
"</div>"
],
"text/plain": [
" day_1 day_2 day_3 day_4 day_5 day_6 day_7 day_8 day_9 day_10 ... \\\n",
"0 2 0 0 4 1 0 6 0 0 0 ... \n",
"1 3 3 1 1 1 0 0 3 0 0 ... \n",
"2 0 0 1 2 1 1 1 1 1 1 ... \n",
"3 0 2 0 1 3 2 1 1 2 2 ... \n",
"4 0 0 0 0 0 0 0 0 0 0 ... \n",
"\n",
" day_18 day_19 day_20 wday month store_id event_name_1 event_name_2 \\\n",
"0 1 2 0 1 1 0 30 4 \n",
"1 3 0 1 3 9 0 6 4 \n",
"2 0 4 1 1 1 0 26 4 \n",
"3 2 0 2 3 3 0 30 4 \n",
"4 0 0 0 3 2 0 30 4 \n",
"\n",
" sell_price item_id \n",
"0 2.98 0 \n",
"1 2.98 0 \n",
"2 2.98 0 \n",
"3 2.98 0 \n",
"4 2.98 0 \n",
"\n",
"[5 rows x 27 columns]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encode the categorical columns as numbers\n",
"from sklearn.preprocessing import LabelEncoder\n",
"# Get only label of item_id\n",
"X_train_df[\"item_id\"] = X_train_df[\"item_id\"].apply(lambda x: x.split(\"_\")[1])\n",
"X_test_df[\"item_id\"] = X_test_df[\"item_id\"].apply(lambda x: x.split(\"_\")[1])\n",
"\n",
"\n",
"label_encoders = {}\n",
"for col in CAT_COLS:\n",
" le = LabelEncoder()\n",
" X_train_df[col] = le.fit_transform(X_train_df[col])\n",
" X_test_df[col] = le.transform(X_test_df[col])\n",
" label_encoders[col] = le\n",
"\n",
"\n",
"X_train_prod_processed = []\n",
"X_test_prod_processed = []\n",
"for i in range(len(X_train_prod)):\n",
" X_train_prod_processed.append(pd.DataFrame(X_train_prod[i], columns=COL_NAMES))\n",
" X_test_prod_processed.append(pd.DataFrame(X_test_prod[i], columns=COL_NAMES))\n",
" X_train_prod_processed[-1][\"item_id\"] = X_train_prod_processed[-1][\"item_id\"].apply(lambda x: x.split(\"_\")[1])\n",
" X_test_prod_processed[-1][\"item_id\"] = X_test_prod_processed[-1][\"item_id\"].apply(lambda x: x.split(\"_\")[1])\n",
" for col in CAT_COLS:\n",
" X_train_prod_processed[-1][col] = label_encoders[col].transform(X_train_prod_processed[-1][col])\n",
" X_test_prod_processed[-1][col] = label_encoders[col].transform(X_test_prod_processed[-1][col])\n",
"\n",
"X_train_df.head()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PPC"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### \"Standard PPCs\""
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def max_ppc(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], number=0, name=\"\") -> None:\n",
" # rpeat y_true to match the shape of y_samples\n",
" max_ppc = np.max(y_samples, axis=1)\n",
" true_max = np.max(y_true)\n",
"\n",
" return max_ppc.flatten(), true_max.flatten(), \"max_ppc\"\n",
"\n",
"def quantile_ppc(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], quantile=0.5, number=0, name=\"\") -> None:\n",
" # rpeat y_true to match the shape of y_samples\n",
" q = np.quantile(y_samples, quantile, axis=1)\n",
" true_q = np.quantile(y_true, quantile)\n",
" return q.flatten(), true_q.flatten(), f\"quantile_ppc_{quantile}\"\n",
"\n",
"def zeros(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], number=0, name=\"\") -> None:\n",
" \"Count the number of zeros in the samples\"\n",
" zeros = np.sum(y_samples < 0.1, axis=1)\n",
" true_zeros = np.sum(y_true < 0.1)\n",
"\n",
" return zeros.flatten(), true_zeros.flatten(), \"zeros\"\n",
"\n",
"def percentage_zeros(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], number=0, name=\"\") -> None:\n",
" \"Count the number of zeros in the samples\"\n",
" zeros = np.mean(y_samples < 0.1, axis=1)\n",
" true_zeros = np.mean(y_true < 0.1)\n",
"\n",
" return zeros.flatten(), true_zeros.flatten(), \"percentage_zeros\""
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def plot_ppcs(y_true: Float[Array, \"batch y_dim\"], y_samples: Float[Array, \"samples batch y_dim\"], ppcs: List[Callable],\n",
" number=0, name=\"\") -> None:\n",
" # plot the distribution of\n",
"\n",
" for ppc in ppcs:\n",
" ppc(y_true, y_samples, number=number, name=name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### \"Complex PPCs\""
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def plot_model_comparisons(data, y_true, figsize=(12, 8), model_names=None):\n",
" \"\"\"\n",
" Plots model predictions against true values for each day.\n",
"\n",
" :param data: numpy array of shape [models, samples, days] containing model predictions\n",
" :param y_true: array of shape [days] containing the true values\n",
" :param figsize: tuple indicating the size of the figure\n",
" \"\"\"\n",
" sns.set(style=\"whitegrid\")\n",
" models, samples, days = data.shape\n",
"\n",
" # Create a figure and axis object\n",
" fig, ax = plt.subplots(figsize=figsize)\n",
"\n",
" # We will transform the data to a format suitable for seaborn\n",
" # Create a DataFrame with model, day, and sample values\n",
" plot_data = []\n",
" if model_names is None:\n",
" model_names = [f\"Model {i}\" for i in range(models)]\n",
"\n",
" for model_idx in range(models):\n",
" for day_idx in range(days):\n",
" for sample_idx in range(samples):\n",
" plot_data.append({\n",
" \"Day\": day_idx,\n",
" \"Value\": data[model_idx, sample_idx, day_idx],\n",
" \"Model\": model_names[model_idx]\n",
" })\n",
"\n",
" import pandas as pd\n",
" plot_data = pd.DataFrame(plot_data)\n",
"\n",
" # Use seaborn to plot the boxplots\n",
" sns.boxplot(x=\"Day\", y=\"Value\", hue=\"Model\", data=plot_data, ax=ax, width=0.6)\n",
"\n",
" # Plot true values\n",
" plt.plot(y_true, 'o', color='red', label='True Values')\n",
"\n",
" # Setting labels and title\n",
" plt.xticks(ticks=np.arange(days), labels=[f\"Day {i+1}\" for i in range(days)])\n",
" plt.xlabel('Days')\n",
" plt.ylabel('Values')\n",
" plt.title('Model Predictions vs. True Values')\n",
" plt.legend()\n",
"\n",
" # Show the plot\n",
" plt.show()\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def save_results_to_pkl(results: dict, dir_name, name):\n",
" if not Path(dir_name).exists():\n",
" Path(dir_name).mkdir(parents=True)\n",
"\n",
" path = Path(dir_name) / f\"{name}.pkl\"\n",
" with open(path, \"wb\") as f:\n",
" pkl.dump(results, f)\n",
"\n",
"\n",
"def load_results_from_pkl(dir_name, name):\n",
" path = Path(dir_name) / f\"{name}.pkl\"\n",
" with open(path, \"rb\") as f:\n",
" results = pkl.load(f)\n",
" return results\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Simple helper function to train a model and plot ppcs\n",
"\n",
"def get_ppcs(y_samples, X_test, y_test, ppcs, number=0, name=\"\") -> None:\n",
" \"\"\"\n",
" Returns a dictionary with the samples and the true values for each ppc\n",
" the dictionary a\n",
" \"\"\"\n",
" y_samples = np.array(y_samples)\n",
" #y_samples = np.maximum(y_samples, 0)\n",
" #y_samples = np.round(y_samples, 0)\n",
"\n",
" ppc_results = {}\n",
" for ppc in ppcs:\n",
" samples, true, name = ppc(y_test, y_samples, number=number, name=name)\n",
" ppc_results[name] = {\"samples\": samples, \"true\": true}\n",
"\n",
" return ppc_results\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" day_1 day_2 day_3 day_4 day_5 day_6 day_7 day_8 day_9 day_10 ... \\\n",
"0 2 0 0 4 1 0 6 0 0 0 ... \n",
"1 3 3 1 1 1 0 0 3 0 0 ... \n",
"2 0 0 1 2 1 1 1 1 1 1 ... \n",
"3 0 2 0 1 3 2 1 1 2 2 ... \n",
"4 0 0 0 0 0 0 0 0 0 0 ... \n",
"\n",
" day_18 day_19 day_20 wday month store_id event_name_1 event_name_2 \\\n",
"0 1 2 0 1 1 0 30 4 \n",
"1 3 0 1 3 9 0 6 4 \n",
"2 0 4 1 1 1 0 26 4 \n",
"3 2 0 2 3 3 0 30 4 \n",
"4 0 0 0 3 2 0 30 4 \n",
"\n",
" sell_price item_id \n",
"0 2.98 0 \n",
"1 2.98 0 \n",
"2 2.98 0 \n",
"3 2.98 0 \n",
"4 2.98 0 \n",
"\n",
"[5 rows x 27 columns]\n"
]
}
],
"source": [
"print(X_train_df.head())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"EVAL_VALUES = 10_000 #len(X_test_df)\n",
"np.random.seed(0)\n",
"\n",
"eval_idx = np.random.choice(len(X_test_df), EVAL_VALUES, replace=False)\n",
"\n",
"X_train_np = X_train_df.values\n",
"X_test_np = X_test_df.values[eval_idx]\n",
"\n",
"y_train_np = y_train_df.values + np.random.normal(0, 0.01, y_train_df.shape)\n",
"y_test_np = y_test_df.values[eval_idx]\n",
"\n",
"# change to float to prevent errors\n",
"y_train_np = y_train_np.astype(np.float32)\n",
"y_test_np = y_test_np.astype(np.float32)\n",
"\n",
"dataset = {\n",
" \"X_train\": X_train_np,\n",
" \"X_test\": X_test_np,\n",
" \"y_train\": y_train_np,\n",
" \"y_test\": y_test_np,\n",
" \"col_names\": COL_NAMES,\n",
" \"cat_cols\": CAT_COLS,\n",
"}\n",
"\n",
"with open(\"dataset.pkl\", \"wb\") as f:\n",
" pkl.dump(dataset, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"\n",
"#MODEL_CLASSES = MODEL_CLASSES[-1:]\n",
"#NAMES = NAMES[-1:]\n",
"\n",
"NUM_SAMPLES = 100\n",
"HYPERS = [\n",
" {\"subsample\": 0.20, \"subsample_freq\": 1, \"verbose\": 0, \"num_leaves\":129, \"learning_rate\":0.5,\n",
" \"sde_manual_hyperparams\": {\"hyperparam_max\": 10}},\n",
" {},\n",
" {}\n",
"]\n",
"MODEL_CLASSES = [Treeffuser]\n",
"NAMES = [\"Treeffuser\"]\n",
"HYPERS = [\n",
" {\"n_estimators\": 3000, \"learning_rate\": 0.1, \"num_leaves\": 32, \"early_stopping_rounds\": 50, \"n_repeats\": 10},\n",
"]\n",
"\n",
"results = []\n",
"for i in range(len(MODEL_CLASSES)):\n",
" model_cls = MODEL_CLASSES[i]\n",
" #model = BayesOptProbabilisticModel(model_cls, n_iter_bayes_opt=20, frac_validation=0.01)\n",
" model = model_cls(**HYPERS[i])\n",
"\n",
"\n",
" if False and model_cls == NGBoostPoisson:\n",
" # shuffle the data\n",
" np.random.seed(0)\n",
" idx = np.random.permutation(len(X_train_np))\n",
" X_train_np_ngb = X_train_np[idx]\n",
" y_train_np_ngb = y_train_np[idx].astype(np.int32)\n",
" model.fit(X_train_np_ngb, y_train_np_ngb)\n",
"\n",
" elif False and model_cls == NGBoostGaussian:\n",
" y_train_np_ngb = y_train_np + np.random.normal(0, 3, y_train_np.shape)\n",
" # rescale\n",
" #y_train_np_ngb = (y_train_np_ngb - np.mean(y_train_np_ngb)) / np.std(y_train_np_ngb)\n",
" model.fit(X_train_np, y_train_np_ngb)\n",
"\n",
" else:\n",
" model.fit(X_train_np, y_train_np)\n",
"\n",
" results.append({\n",
" \"model\": model,\n",
" \"model_name\": NAMES[i]\n",
" })\n",
"\n",
" save_results_to_pkl(results, \"m5\", \"results.pkl\")\n",
"\n",
"\n",
"results = load_results_from_pkl(\"m5\", \"results.pkl\")\n",
"\n",
"\n",
"# Save the results\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Treeffuser\n"
]
}
],
"source": [
"\n",
"for i, result in enumerate(results):\n",
" model = result[\"model\"]\n",
" model_name = result[\"model_name\"]\n",
" print(f\"Model: {model_name}\")\n",
" y_samples = model.sample(X_test_np, NUM_SAMPLES)\n",
" results[i][\"y_samples\"] = y_samples"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"results_no_model = []\n",
"for result in results:\n",
" results_no_model.append({k: v for k, v in result.items() if k != \"model\"})\n",
"\n",
"# Don't uncomment or will overwrite the results\n",
"#save_results_to_pkl(results, \"m5\", \"results_final.pkl\")\n",
"#save_results_to_pkl(results_no_model, \"m5\", \"results_no_model_final.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" results_no_model = load_results_from_pkl(\"m5\", \"results_final.pkl\")\n",
"except:\n",
" results_no_model = results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can actually fit some of the models"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
}
],
"source": [
"ppcs = [max_ppc, zeros, percentage_zeros] + [partial(quantile_ppc, quantile=q) for q in [0.1, 0.5, 0.9, 0.99, 0.999]]\n",
"\n",
"for i, model_cls in enumerate(MODEL_CLASSES):\n",
" print(i)\n",
" ppc_results = get_ppcs(\n",
" y_samples=results_no_model[i][\"y_samples\"],\n",
" X_test=X_test_np,\n",
" y_test=y_test_np,\n",
" ppcs=ppcs,\n",
" number=i,\n",
" name=model_cls.__name__\n",
" )\n",
" results[i][\"ppc_results\"] = ppc_results\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plot the PPCs"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# Titles for plots\n",
"ppc_tiles = {\n",
" \"max_ppc\": \"$\\max$\",\n",
" \"zeros\": r\"$\\text{zeros}$\",\n",
" \"quantile_ppc_0.99\": \"$q_{0.99}$\",\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def set_plot_style():\n",
" \"\"\"\n",
" Sets a common plotting style for all of the figures that will be\n",
" used in the final paper.\n",
" \"\"\"\n",
"\n",
" # no grid but white with pretty ticks\n",
"\n",
"\n",
" # use latex font by default\n",
" plt.rc(\"text\", usetex=False)\n",
" plt.rc(\"font\", family=\"serif\")\n",
" sns.set_style(\"white\")\n",
"\n",
" # make it ready for a presentation\n",
" sns.set_context(\"talk\")\n",
"set_plot_style()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [204. 182. 267. 127. 123. 134. 205. 223. 223. 194. 191. 184. 101. 105.\n",
" 155. 166. 194. 130. 122. 139. 147. 216. 220. 144. 183. 237. 147. 182.\n",
" 156. 108. 243. 247. 118. 154. 184. 114. 111. 141. 190. 289. 140. 254.\n",
" 159. 241. 187. 119. 167. 222. 152. 96. 140. 177. 178. 137. 200. 233.\n",
" 152. 292. 239. 145. 263. 186. 112. 172. 219. 168. 136. 132. 162. 159.\n",
" 247. 109. 154. 120. 186. 166. 172. 181. 149. 140. 204. 146. 247. 204.\n",
" 269. 148. 124. 133. 101. 170. 150. 275. 138. 117. 169. 110. 345. 161.\n",
" 108. 104.]\n",
"title max_ppc\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [5693 5628 5649 5628 5680 5641 5678 5637 5657 5694 5597 5599 5685 5607\n",
" 5676 5528 5707 5590 5655 5630 5618 5563 5675 5668 5624 5655 5671 5625\n",
" 5647 5610 5656 5634 5696 5759 5741 5684 5648 5617 5632 5564 5616 5656\n",
" 5635 5625 5682 5596 5696 5661 5726 5645 5644 5622 5660 5675 5653 5671\n",
" 5682 5702 5670 5640 5562 5666 5645 5703 5690 5566 5693 5641 5659 5676\n",
" 5587 5644 5641 5669 5682 5612 5632 5644 5670 5684 5643 5665 5679 5666\n",
" 5666 5657 5695 5681 5680 5661 5679 5684 5681 5686 5707 5701 5682 5606\n",
" 5652 5636]\n",
"title zeros\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 1.]\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0.]\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0.]\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [3. 3. 3. 3. 4. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.\n",
" 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.\n",
" 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.\n",
" 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.\n",
" 3. 3. 3. 3.]\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [15. 14. 14. 14. 13. 13. 14. 14. 15. 14. 14. 15. 13. 14. 14. 15. 14. 14.\n",
" 14. 14. 14. 14. 13. 14. 13. 13. 15. 15. 14. 14. 14. 14. 14. 15. 14. 13.\n",
" 14. 14. 15. 13. 13. 13. 14. 14. 14. 14. 14. 14. 14. 15. 14. 14. 13. 15.\n",
" 14. 14. 14. 13. 15. 14. 14. 15. 14. 14. 14. 14. 14. 14. 14. 15. 14. 14.\n",
" 14. 14. 13. 14. 14. 15. 14. 14. 14. 13. 13. 14. 14. 14. 14. 15. 15. 13.\n",
" 14. 13. 13. 14. 14. 13. 14. 14. 15. 14.]\n",
"title quantile_ppc_0.99\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"samples [33. 40. 36. 32. 39. 39. 39. 35. 38. 39. 39. 39. 37. 45. 35. 42. 37. 41.\n",
" 42. 34. 38. 34. 40. 40. 45. 39. 39. 37. 37. 42. 35. 38. 36. 38. 36. 38.\n",
" 37. 37. 41. 42. 36. 37. 40. 42. 37. 38. 32. 39. 40. 36. 34. 39. 38. 40.\n",
" 40. 40. 40. 37. 36. 34. 37. 35. 42. 37. 48. 38. 35. 37. 42. 41. 42. 42.\n",
" 37. 34. 41. 34. 34. 41. 39. 37. 33. 36. 39. 40. 38. 38. 35. 38. 36. 38.\n",
" 46. 32. 34. 41. 40. 37. 41. 37. 38. 39.]\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def proc_title(title):\n",
" if title in ppc_tiles:\n",
" print(\"title\", title)\n",
" return ppc_tiles[title]\n",
"\n",
" x = title.replace(\"_\", \" \").capitalize()\n",
" x = x.replace(\"ppc\", \"\")\n",
" return x\n",
"\n",
"ppc_number = len(ppcs)\n",
"ppc_names = results[0][\"ppc_results\"].keys()\n",
"\n",
"for ppc_name in ppc_names:\n",
" fig, ax = plt.subplots()\n",
" for i, res in enumerate(results):\n",
" model_name = res[\"model_name\"]\n",
" samples = res[\"ppc_results\"][ppc_name][\"samples\"]\n",
" # make int\n",
" samples = np.maximum(samples, 0)\n",
" samples = np.round(samples)\n",
" true = res[\"ppc_results\"][ppc_name][\"true\"]\n",
" n_unique_samples = len(np.unique(samples))\n",
" discrete = n_unique_samples < 20\n",
"\n",
" print(\"samples\", samples)\n",
"\n",
" # plot a histogram of the samples but with integers (use nice binning)\n",
" if ppc_name == \"max_ppc\":\n",
" binwidth = 70\n",
" else:\n",
" binwidth = None\n",
" sns.histplot(samples, ax=ax, label=f\"{model_name}\" + r\" $\\hat{p}$\", discrete=discrete, stat=\"density\", binwidth=binwidth)\n",
" if i == 0:\n",
" ax.axvline(true, color=\"red\", label=\"Observed Value\")\n",
" ax.set_title(proc_title(f\"{ppc_name}\"))\n",
" ax.legend()\n",
"\n",
" if n_unique_samples < 2:\n",
" min_val = np.min(samples) - 1\n",
" max_val = np.max(samples) + 1\n",
"\n",
" ax.set_xlim(min_val, max_val)\n",
"\n",
"\n",
"\n",
" # save the figure\n",
" fig.savefig(f\"m5/{ppc_name}.png\", dpi=100)\n",
" # save as pdf\n",
" fig.savefig(f\"m5/{ppc_name}.pdf\")\n",
"\n",
" #max_x = true * 5\n",
" #max_samples = np.max(samples)\n",
" #if max_samples > max_x:\n",
" # ax.set_xlim(0, max_x)\n",
"\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x3acbe1d00>"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i, result in enumerate(results):\n",
" samples = result[\"y_samples\"][0]\n",
" samples = np.round(samples) + (i+1)/4\n",
" sns.histplot(samples.flatten(), stat=\"density\", label=result[\"model_name\"])\n",
"\n",
"sns.histplot(y_test_np.flatten(), stat=\"density\", color=\"red\", label=\"true\")\n",
"plt.xlim(0, 10)\n",
"plt.legend()\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean shape (10,)\n",
"lower shape (10,)\n",
"upper shape (10,)\n",
"y_true shape (10,)\n",
"samples shape (1, 100, 10)\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean shape (10,)\n",
"lower shape (10,)\n",
"upper shape (10,)\n",
"y_true shape (10,)\n",
"samples shape (1, 100, 10)\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean shape (10,)\n",
"lower shape (10,)\n",
"upper shape (10,)\n",
"y_true shape (10,)\n",
"samples shape (1, 100, 10)\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"np.random.seed(0)\n",
"\n",
"NUM_PRODS_TO_PLOT = 3\n",
"DAYS_TO_PLOT = 10\n",
"QUATILE = 0.9\n",
"NUM_SAMPLES = 100\n",
"\n",
"prods_to_plot = np.random.choice(range(len(X_train_prod_processed)), NUM_PRODS_TO_PLOT)\n",
"\n",
"means_to_plot = []\n",
"lower_q_s = []\n",
"upper_q_s = []\n",
"\n",
"prod_dict = {}\n",
"\n",
"for res in results:\n",
" model = res[\"model\"]\n",
"\n",
" for i in prods_to_plot:\n",
" if i not in prod_dict:\n",
" prod_dict[i] = {\n",
" \"means_to_plot\": [],\n",
" \"lower_q_s\": [],\n",
" \"upper_q_s\": [],\n",
" \"samples\": []\n",
" }\n",
"\n",
" X_prod_proc_i = X_train_prod_processed[i][-DAYS_TO_PLOT:]\n",
" y_prod_proc_i = y_train_prod[i][-DAYS_TO_PLOT:]\n",
"\n",
" samples = model.sample(X_prod_proc_i.values, NUM_SAMPLES)\n",
" samples = samples.astype(int)\n",
" samples = np.maximum(samples, 0)\n",
"\n",
" means = np.mean(samples, axis=0)\n",
" lower_q = np.quantile(samples, 1-QUATILE, axis=0)\n",
" upper_q = np.quantile(samples, QUATILE, axis=0)\n",
"\n",
" prod_dict[i][\"means_to_plot\"].append(means)\n",
" prod_dict[i][\"lower_q_s\"].append(lower_q)\n",
" prod_dict[i][\"upper_q_s\"].append(upper_q)\n",
" prod_dict[i][\"samples\"].append(samples)\n",
"\n",
"for i in prod_dict:\n",
" means_to_plot = np.array(prod_dict[i][\"means_to_plot\"]).squeeze()\n",
" lower_q_s = np.array(prod_dict[i][\"lower_q_s\"]).squeeze()\n",
" upper_q_s = np.array(prod_dict[i][\"upper_q_s\"]).squeeze()\n",
" samples = np.array(prod_dict[i][\"samples\"]).squeeze(-1)\n",
" y_true = np.array(y_train_prod[i][-DAYS_TO_PLOT:])\n",
"\n",
" model_names = [res[\"model\"].__class__.__name__ for res in results]\n",
"\n",
" print(\"mean shape\", means_to_plot.shape)\n",
" print(\"lower shape\", lower_q_s.shape)\n",
" print(\"upper shape\", upper_q_s.shape)\n",
" print(\"y_true shape\", y_true.shape)\n",
" print(\"samples shape\", samples.shape)\n",
"\n",
" #plot_predictions(y_true, means_to_plot, upper_q_s, lower_q_s, [res[\"model\"].__class__.__name__ for res in results])\n",
" plot_model_comparisons(samples, y_true, model_names=model_names)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Log-likelihood"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Treeffuser has a NLL of {'nll_samples': 1.2175196664428274}\n",
"Model Treeffuser has a CRPS of {'crps_100': 0.6524468580538034}\n"
]
}
],
"source": [
"for res in results:\n",
" model = res[\"model\"]\n",
" name = res[\"model_name\"]\n",
" nll = LogLikelihoodFromSamplesMetric(n_samples=100).compute(model=model, X_test=X_test_np, y_test=y_test_np, samples=res[\"y_samples\"])\n",
" crps = CRPS().compute(model=model, X_test=X_test_np, y_test=y_test_np, samples=res[\"y_samples\"])\n",
" print(f\"Model {name} has a NLL of {nll}\")\n",
" print(f\"Model {name} has a CRPS of {crps}\")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'mae': 0.9883351454782684, 'rmse': 2.1792834439725737, 'mdae': 0.5293582833519341, 'marpd': 141.8933048278295, 'r2': 0.6827293677776689, 'corr': 0.8279086805309621}\n"
]
},
{
"ename": "AssertionError",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[39], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m AccuracyMetric()\u001b[38;5;241m.\u001b[39mcompute(model\u001b[38;5;241m=\u001b[39mmodel, X_test\u001b[38;5;241m=\u001b[39mX_test_np, y_test\u001b[38;5;241m=\u001b[39my_test_np)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(accuracy)\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"\u001b[0;31mAssertionError\u001b[0m: "
]
}
],
"source": [
"accuracy = AccuracyMetric().compute(model=model, X_test=X_test_np, y_test=y_test_np)\n",
"print(accuracy)\n",
"assert False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plot calibration plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def calibration_plot(y_samples: Float[np.ndarray, \"n_samples batch y_dim\"], y_test: Float[np.ndarray, \"batch y_dim\"]) -> None:\n",
" \"\"\"\n",
" We will plot the calibration plot for the model. Essentially, we will plot the\n",
" \"\"\"\n",
" assert y_test.shape[1] == 1, \"Only works for univariate outputs\"\n",
" n_samples = y_samples.shape[0]\n",
" y_samples = np.maximum(y_samples, 0.0)\n",
" y_samples = np.round(y_samples).astype(int)\n",
" y_test = y_test.astype(int)\n",
"\n",
" # Filter out the zeros\n",
" #non_zero_idx = y_test > 0\n",
" #y_test = y_test[non_zero_idx]\n",
" #y_samples = y_samples[:, non_zero_idx]\n",
"\n",
" y_test_expanded = y_test[np.newaxis, :].repeat(n_samples, axis=0)\n",
" prob_of_event = np.mean(y_samples <= y_test_expanded, axis=0)\n",
" prob_of_event_sorted = np.sort(prob_of_event.flatten())\n",
" return np.linspace(0, 1, len(prob_of_event)), prob_of_event_sorted\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_ex = None\n",
"for res in results:\n",
" samples = res[\"y_samples\"]\n",
" model = res[\"model\"]\n",
" x,y = calibration_plot(samples, y_test_np)\n",
" plt.plot(x, y, label=res[\"model_name\"])\n",
" y_ex = y\n",
"\n",
"\n",
"plt.plot([0, 1], [0, 1], linestyle=\"--\", color=\"black\")\n",
"\n",
"plt.xlabel(\"Predicted Probability\")\n",
"plt.ylabel(\"True Probability\")\n",
"plt.legend()\n",
"\n",
"plt.show()\n",
"\n",
"# plot the distribution of y\n",
"\n",
"sns.histplot(y_ex, bins=20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quantile prediction plot for the model\n",
"def make_quantile_plot(results):\n",
" quantiles = np.linspace(0.9, 0.999, 10)\n",
"\n",
" means = []\n",
" stds = []\n",
"\n",
" for res in results:\n",
" samples = res[\"y_samples\"] # shape [n_samples, batch, y_dim]\n",
" samples = np.maximum(samples, 0)\n",
" samples = np.round(samples)\n",
"\n",
" m, s = [], []\n",
"\n",
" for q in quantiles:\n",
" quantile_samples = np.quantile(samples, q, axis=1)\n",
" mean = np.mean(quantile_samples, axis=0)\n",
" std = np.std(quantile_samples, axis=0)\n",
" m.append(mean)\n",
" s.append(std)\n",
"\n",
" m = np.array(m).squeeze()\n",
" s = np.array(s).squeeze()\n",
" plt.plot(quantiles, m, label=res[\"model_name\"])\n",
" plt.fill_between(quantiles, m - s, m + s, alpha=0.3)\n",
"\n",
" means.append(m)\n",
" stds.append(s)\n",
"\n",
"\n",
" true_quantiles = []\n",
" for q in quantiles:\n",
" true_quantiles.append(np.quantile(y_test_np, q))\n",
"\n",
" plt.plot(quantiles, true_quantiles, label=\"True\")\n",
"\n",
" # set log scale on y axis\n",
"\n",
"\n",
" # set log scale on x axis\n",
" plt.xscale(\"log\")\n",
" plt.legend()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"make_quantile_plot(results)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"testbed.models.treeffuser.Treeffuser"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_cls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment