Created December 4, 2025 15:45
-
-
Save tteofili/e315844b5e2bab607b6d0c4d5f8a54c3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells" : [ { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:39:49.759345Z", | |
| "start_time" : "2025-12-04T15:39:49.756767Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : "", | |
| "id" : "e635173ef70d8762", | |
| "outputs" : [ ], | |
| "execution_count" : null | |
| }, { | |
| "metadata" : { | |
| "collapsed" : true, | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T14:00:03.890791Z", | |
| "start_time" : "2025-12-04T14:00:03.885325Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "import pandas as pd\n", "import json\n", "import operator\n", "\n", "from ellmer.selfexplainer import SelfExplainer" ], | |
| "id" : "909306fc9640edd9", | |
| "outputs" : [ ], | |
| "execution_count" : 141 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T14:15:49.746740Z", | |
| "start_time" : "2025-12-04T14:15:49.734664Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "def perturb_for_target_label(row, ranked_cols, predict_fn, label, k):\n", "    \"\"\"Perturb the k most salient attributes of a record pair and re-predict.\n", "\n", "    If label == 1, each of the top-k attributes is blanked (set to ''); otherwise each\n", "    one is overwritten with the value of its ltable_/rtable_ counterpart. Presumably\n", "    this is meant to move the model's decision away from the original label.\n", "\n", "    Args:\n", "        row: one-row pd.DataFrame, or pd.Series, keyed by 'ltable_*'/'rtable_*'\n", "            attribute names. It is copied, never modified in place.\n", "        ranked_cols: attribute names sorted by decreasing saliency.\n", "        predict_fn: callable invoked once with the perturbed copy.\n", "        label: the originally predicted label (1 = match).\n", "        k: number of top-ranked attributes to perturb (0 = none).\n", "\n", "    Returns:\n", "        Tuple (perturbed_row, predict_fn(perturbed_row)).\n", "\n", "    Raises:\n", "        ValueError: if label != 1 and a perturbed attribute has neither prefix.\n", "    \"\"\"\n", "    row_c = row.copy()\n", "    for col in ranked_cols[:k]:\n", "        if label == 1:\n", "            row_c[col] = ''\n", "            continue\n", "        if col.startswith('ltable_'):\n", "            counterpart = 'rtable_' + col[len('ltable_'):]\n", "        elif col.startswith('rtable_'):\n", "            counterpart = 'ltable_' + col[len('rtable_'):]\n", "        else:\n", "            raise ValueError(f'Column does not start with valid prefix: {col}')\n", "\n", "        # A pd.Series exposes its labels via .index, a DataFrame via .columns; looked up\n", "        # per iteration because earlier assignments may have added new labels.\n", "        labels = row_c.index if isinstance(row_c, pd.Series) else row_c.columns\n", "        if counterpart in labels:\n", "            # Reads the already-perturbed copy on purpose: if both sides of a pair are in\n", "            # the top-k, the second assignment keeps them equal instead of swapping them.\n", "            row_c[col] = row_c[counterpart]\n", "        else:\n", "            print(f'no counterpart found for {counterpart}')\n", "\n", "    return row_c, predict_fn(row_c)\n" ], | |
| "id" : "4af866c094810b13", | |
| "outputs" : [ ], | |
| "execution_count" : 175 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:03:04.623161Z", | |
| "start_time" : "2025-12-04T12:03:04.605366Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Attribute-level self-explaining LLM matcher; predict_fn sends record pairs to it.\n", "llm = SelfExplainer(\n", "    model_type='azure_openai',\n", "    model_name='gpt-5-nano',\n", "    explanation_granularity='attribute',\n", "    temperature=1,\n", "    prompts={'ptse': {'er': '../ellmer/prompts/er.txt'}},\n", ")" ], | |
| "id" : "b56f71faca9a505a", | |
| "outputs" : [ ], | |
| "execution_count" : 18 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:10:56.879137Z", | |
| "start_time" : "2025-12-04T15:10:56.872534Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "def predict_fn(row):\n", "    \"\"\"Return the first 'match_score' value that llm.predict gives for one record pair.\n", "\n", "    row may be a one-row pd.DataFrame or a pd.Series of attribute values. A Series is\n", "    first turned into a one-row frame: pd.DataFrame(series) would instead build a single\n", "    column whose row labels are the attribute names.\n", "    \"\"\"\n", "    frame = row.to_frame().T if isinstance(row, pd.Series) else pd.DataFrame(row)\n", "    return llm.predict(frame)['match_score'].values[0]\n" ], | |
| "id" : "a710f7acb402396a", | |
| "outputs" : [ ], | |
| "execution_count" : 183 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T13:58:37.335056Z", | |
| "start_time" : "2025-12-04T13:58:33.950135Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:03<00:00, 3.37s/it]\u001B[A" ] | |
| }, { | |
| "name" : "stdout", | |
| "output_type" : "stream", | |
| "text" : [ "ltable_name Canon EOS digital camera\n", "rtable_name Canon EOS 90D camera body\n", "dtype: object\n" ] | |
| }, { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "\n" ] | |
| } ], | |
| "execution_count" : 140, | |
| "source" : [ "# FIXME: token_perturb_for_target_label is not defined anywhere in this notebook; it was\n", "# presumably left in the kernel by a deleted cell, so this cell raises NameError after\n", "# Restart & Run All. `ranked` appears to hold token-level names ('<column>__<token>').\n", "row = pd.Series({\n", " \"ltable_name\": \"Canon EOS 80D digital camera\",\n", " \"rtable_name\": \"Canon EOS 90D camera body\",\n", "})\n", "\n", "ranked = [\n", " \"ltable_name__80D\",\n", " \"rtable_name__90D\",\n", "]\n", "\n", "perturbed, pred = token_perturb_for_target_label(\n", " row, ranked, predict_fn, label=1, k=1\n", ")\n", "\n", "print(perturbed)\n" ], | |
| "id" : "6e0e039e08657c82" | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:30:45.760222Z", | |
| "start_time" : "2025-12-04T12:30:45.757545Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# NOTE: absolute, machine-specific path; point it at your own experiment output.\n", "RESULTS_PATH = '/Users/tteofili/dev/ellmer/experiments/azure_openai/gpt-5-nano/attribute/cameras_small/20251204/12_43/2_results.json'\n", "\n", "# JSON is UTF-8 by spec; don't rely on the platform's default text encoding.\n", "with open(RESULTS_PATH, encoding='utf-8') as f:\n", "    d = json.load(f)" ], | |
| "id" : "20da3c2e2a410071", | |
| "outputs" : [ ], | |
| "execution_count" : 56 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:33:39.917203Z", | |
| "start_time" : "2025-12-04T12:33:39.912517Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# One-row frames for the left/right records of the 'zs_sample' entry,\n", "# with every column prefixed by the side it comes from.\n", "ltuple, rtuple = (\n", "    pd.json_normalize(d['data'][0]['zs_sample'][key]).add_prefix(prefix)\n", "    for key, prefix in (('ltuple', 'ltable_'), ('rtuple', 'rtable_'))\n", ")" ], | |
| "id" : "e6b98d295d5adc27", | |
| "outputs" : [ ], | |
| "execution_count" : 66 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T14:03:33.542898Z", | |
| "start_time" : "2025-12-04T14:03:33.531077Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Left and right records side by side as a single pair row.\n", "wdc_cameras_row = pd.concat((ltuple, rtuple), axis='columns')\n", "wdc_cameras_row" ], | |
| "id" : "3bba24732497f80f", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_category ltable_cluster_id ltable_brand \\\n", "0 Camera_and_Photo 660483.0 nan \n", "\n", " ltable_title \\\n", "0 Canon Digital Rebel T7i 18-55mm Kit\"@en-US Kit... \n", "\n", " ltable_description ltable_price \\\n", "0 What's included EOS Rebel T7i EF-S 18-55mm 4-... nan \n", "\n", " ltable_specTableContent rtable_category rtable_cluster_id rtable_brand \\\n", "0 nan Camera_and_Photo 3092335.0 nan \n", "\n", " rtable_title \\\n", "0 \"Canon Lens Hood ES-52 for EF 44mm f/2.8 STM\"... \n", "\n", " rtable_description rtable_price \\\n", "0 \" Prevents stray light from entering the lens... nan \n", "\n", " rtable_specTableContent \n", "0 nan " ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>ltable_category</th>\n", " <th>ltable_cluster_id</th>\n", " <th>ltable_brand</th>\n", " <th>ltable_title</th>\n", " <th>ltable_description</th>\n", " <th>ltable_price</th>\n", " <th>ltable_specTableContent</th>\n", " <th>rtable_category</th>\n", " <th>rtable_cluster_id</th>\n", " <th>rtable_brand</th>\n", " <th>rtable_title</th>\n", " <th>rtable_description</th>\n", " <th>rtable_price</th>\n", " <th>rtable_specTableContent</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Camera_and_Photo</td>\n", " <td>660483.0</td>\n", " <td>nan</td>\n", " <td>Canon Digital Rebel T7i 18-55mm Kit\"@en-US Kit...</td>\n", " <td>What's included EOS Rebel T7i EF-S 18-55mm 4-...</td>\n", " <td>nan</td>\n", " <td>nan</td>\n", " <td>Camera_and_Photo</td>\n", " <td>3092335.0</td>\n", " <td>nan</td>\n", " <td>\"Canon Lens Hood ES-52 for EF 44mm f/2.8 STM\"...</td>\n", " <td>\" Prevents stray light from entering the lens...</td>\n", " <td>nan</td>\n", " <td>nan</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 146, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 146 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:31:12.361585Z", | |
| "start_time" : "2025-12-04T12:31:12.358224Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Label predicted for the 'zs_sample' pair, kept as a one-element Series.\n", "prediction = pd.json_normalize(d['data'][0]['zs_sample']).loc[:, 'prediction']\n", "prediction" ], | |
| "id" : "2565ccf869a697d5", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ "0 0\n", "Name: prediction, dtype: int64" ] | |
| }, | |
| "execution_count" : 62, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 62 | |
| }, { | |
| "metadata" : { }, | |
| "cell_type" : "code", | |
| "outputs" : [ ], | |
| "execution_count" : null, | |
| "source" : "", | |
| "id" : "9fe431c5ad2e4300" | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:27:55.326675Z", | |
| "start_time" : "2025-12-04T12:27:55.320310Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Saliency scores stored with the 'zs_sample' prediction (the LLM's self-explanation).\n", "self_explanation = pd.json_normalize(data=d['data'][0]['zs_sample']['saliency'])\n", "self_explanation" ], | |
| "id" : "bcd247cd99207821", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_category ltable_cluster_id ltable_brand ltable_title \\\n", "0 0.25 0.1 0.05 0.4 \n", "\n", " ltable_description ltable_price ltable_specTableContent rtable_category \\\n", "0 0.45 0.05 0.0 0.25 \n", "\n", " rtable_cluster_id rtable_brand rtable_title rtable_description \\\n", "0 0.1 0.05 0.4 0.45 \n", "\n", " rtable_price rtable_specTableContent \n", "0 0.05 0.0 " ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>ltable_category</th>\n", " <th>ltable_cluster_id</th>\n", " <th>ltable_brand</th>\n", " <th>ltable_title</th>\n", " <th>ltable_description</th>\n", " <th>ltable_price</th>\n", " <th>ltable_specTableContent</th>\n", " <th>rtable_category</th>\n", " <th>rtable_cluster_id</th>\n", " <th>rtable_brand</th>\n", " <th>rtable_title</th>\n", " <th>rtable_description</th>\n", " <th>rtable_price</th>\n", " <th>rtable_specTableContent</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.25</td>\n", " <td>0.1</td>\n", " <td>0.05</td>\n", " <td>0.4</td>\n", " <td>0.45</td>\n", " <td>0.05</td>\n", " <td>0.0</td>\n", " <td>0.25</td>\n", " <td>0.1</td>\n", " <td>0.05</td>\n", " <td>0.4</td>\n", " <td>0.45</td>\n", " <td>0.05</td>\n", " <td>0.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 46, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 46 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:31:18.138580Z", | |
| "start_time" : "2025-12-04T12:31:18.127412Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Post-hoc saliency scores stored under the 'certa_sample' entry.\n", "post_hoc_explanation = pd.json_normalize(data=d['data'][0]['certa_sample']['saliency'])\n", "post_hoc_explanation" ], | |
| "id" : "ea0e7d561021f797", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_category ltable_cluster_id ltable_brand ltable_title \\\n", "0 0.152993 0.170732 0.13969 0.317073 \n", "\n", " ltable_description ltable_price ltable_specTableContent rtable_category \\\n", "0 0.192905 0.152993 0.152993 0.334812 \n", "\n", " rtable_cluster_id rtable_brand rtable_title rtable_description \\\n", "0 0.412417 0.332594 0.682927 0.40133 \n", "\n", " rtable_price rtable_specTableContent \n", "0 0.334812 0.32816 " ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>ltable_category</th>\n", " <th>ltable_cluster_id</th>\n", " <th>ltable_brand</th>\n", " <th>ltable_title</th>\n", " <th>ltable_description</th>\n", " <th>ltable_price</th>\n", " <th>ltable_specTableContent</th>\n", " <th>rtable_category</th>\n", " <th>rtable_cluster_id</th>\n", " <th>rtable_brand</th>\n", " <th>rtable_title</th>\n", " <th>rtable_description</th>\n", " <th>rtable_price</th>\n", " <th>rtable_specTableContent</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.152993</td>\n", " <td>0.170732</td>\n", " <td>0.13969</td>\n", " <td>0.317073</td>\n", " <td>0.192905</td>\n", " <td>0.152993</td>\n", " <td>0.152993</td>\n", " <td>0.334812</td>\n", " <td>0.412417</td>\n", " <td>0.332594</td>\n", " <td>0.682927</td>\n", " <td>0.40133</td>\n", " <td>0.334812</td>\n", " <td>0.32816</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 63, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 63 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:58:10.509733Z", | |
| "start_time" : "2025-12-04T12:58:10.503045Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Stack both explanations; the outer index level ('self' / 'post_hoc') names the explainer.\n", "explanations = pd.concat(\n", "    objs=[self_explanation, post_hoc_explanation],\n", "    keys=['self', 'post_hoc'],\n", ")\n", "explanations.head()" ], | |
| "id" : "92275ff5b2d6fc4d", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_category ltable_cluster_id ltable_brand ltable_title \\\n", "self 0 0.250000 0.100000 0.05000 0.400000 \n", "post_hoc 0 0.152993 0.170732 0.13969 0.317073 \n", "\n", " ltable_description ltable_price ltable_specTableContent \\\n", "self 0 0.450000 0.050000 0.000000 \n", "post_hoc 0 0.192905 0.152993 0.152993 \n", "\n", " rtable_category rtable_cluster_id rtable_brand rtable_title \\\n", "self 0 0.250000 0.100000 0.050000 0.400000 \n", "post_hoc 0 0.334812 0.412417 0.332594 0.682927 \n", "\n", " rtable_description rtable_price rtable_specTableContent \n", "self 0 0.45000 0.050000 0.00000 \n", "post_hoc 0 0.40133 0.334812 0.32816 " ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th></th>\n", " <th>ltable_category</th>\n", " <th>ltable_cluster_id</th>\n", " <th>ltable_brand</th>\n", " <th>ltable_title</th>\n", " <th>ltable_description</th>\n", " <th>ltable_price</th>\n", " <th>ltable_specTableContent</th>\n", " <th>rtable_category</th>\n", " <th>rtable_cluster_id</th>\n", " <th>rtable_brand</th>\n", " <th>rtable_title</th>\n", " <th>rtable_description</th>\n", " <th>rtable_price</th>\n", " <th>rtable_specTableContent</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>self</th>\n", " <th>0</th>\n", " <td>0.250000</td>\n", " <td>0.100000</td>\n", " <td>0.05000</td>\n", " <td>0.400000</td>\n", " <td>0.450000</td>\n", " <td>0.050000</td>\n", " <td>0.000000</td>\n", " <td>0.250000</td>\n", " <td>0.100000</td>\n", " <td>0.050000</td>\n", " <td>0.400000</td>\n", " <td>0.45000</td>\n", " <td>0.050000</td>\n", " <td>0.00000</td>\n", " </tr>\n", " <tr>\n", " <th>post_hoc</th>\n", " <th>0</th>\n", " <td>0.152993</td>\n", " <td>0.170732</td>\n", " <td>0.13969</td>\n", " <td>0.317073</td>\n", " <td>0.192905</td>\n", " <td>0.152993</td>\n", " <td>0.152993</td>\n", " <td>0.334812</td>\n", " <td>0.412417</td>\n", " <td>0.332594</td>\n", " <td>0.682927</td>\n", " <td>0.40133</td>\n", " <td>0.334812</td>\n", " <td>0.32816</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 109, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 109 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T12:58:15.123491Z", | |
| "start_time" : "2025-12-04T12:58:14.985275Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "ax = explanations.T.plot(kind='bar', stacked=False, figsize=(12, 6), rot=58)\n", "ax.set(title='Attribute saliency per explainer', xlabel='attribute', ylabel='saliency')\n", "# Columns of explanations.T are (explainer, row) tuples; label the legend by explainer only.\n", "handles = ax.get_legend_handles_labels()[0]\n", "ax.legend(handles, explanations.index.get_level_values(0).tolist(), title='explainer');" ], | |
| "id" : "fcf69878d4338a74", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ "<Axes: >" ] | |
| }, | |
| "execution_count" : 110, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| }, { | |
| "data" : { | |
| "text/plain" : [ "<Figure size 1200x600 with 1 Axes>" ], | |
| "image/png" : "" | |
| }, | |
| "metadata" : { }, | |
| "output_type" : "display_data", | |
| "jetTransient" : { | |
| "display_id" : null | |
| } | |
| } ], | |
| "execution_count" : 110 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T14:16:02.544493Z", | |
| "start_time" : "2025-12-04T14:15:57.953584Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Attributes ranked by the self-explanation (first row of explanations), most salient first.\n", "expl1 = [col for col, _ in sorted(explanations.iloc[0].to_dict().items(), key=operator.itemgetter(1), reverse=True)]\n", "# .iloc[0] takes the scalar label; int() on a 1-element array is deprecated since NumPy 1.25.\n", "perturb_for_target_label(wdc_cameras_row, expl1, predict_fn, label=int(prediction.iloc[0]), k=5)" ], | |
| "id" : "2af6f53ef7514183", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/2824996153.py:2: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " perturb_for_target_label(wdc_cameras_row, expl1, predict_fn, label=int(prediction.values), k=5)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:04<00:00, 4.57s/it]\u001B[A\n" ] | |
| }, { | |
| "data" : { | |
| "text/plain" : [ "( ltable_category ltable_cluster_id ltable_brand \\\n", " 0 Camera_and_Photo 660483.0 nan \n", " \n", " ltable_title \\\n", " 0 \"Canon Lens Hood ES-52 for EF 44mm f/2.8 STM\"... \n", " \n", " ltable_description ltable_price \\\n", " 0 \" Prevents stray light from entering the lens... nan \n", " \n", " ltable_specTableContent rtable_category rtable_cluster_id rtable_brand \\\n", " 0 nan Camera_and_Photo 3092335.0 nan \n", " \n", " rtable_title \\\n", " 0 \"Canon Lens Hood ES-52 for EF 44mm f/2.8 STM\"... \n", " \n", " rtable_description rtable_price \\\n", " 0 \" Prevents stray light from entering the lens... nan \n", " \n", " rtable_specTableContent \n", " 0 nan ,\n", " array([1]))" ] | |
| }, | |
| "execution_count" : 176, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 176 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:11:32.767610Z", | |
| "start_time" : "2025-12-04T15:11:01.605846Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "res = []\n", "\n", "k_range = 6\n", "# Extract the scalar label once; int(prediction.values) is deprecated since NumPy 1.25.\n", "target_label = int(prediction.iloc[0])\n", "\n", "# For each explainer, perturb its top-k attributes for k = 0 .. k_range - 1 and record the\n", "# new prediction. k = 0 leaves the row untouched, so every group includes a baseline run.\n", "for index_key, expl in explanations.iterrows():\n", "    explainer = index_key[0]\n", "    salient_features = [col for col, _ in sorted(expl.to_dict().items(), key=operator.itemgetter(1), reverse=True)]\n", "    for k in range(k_range):\n", "        pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=target_label, k=k)\n", "        res.append({'prediction': pred, 'perturbed': pert, 'explainer': explainer})\n", "\n", "consistency = pd.DataFrame(res)\n" ], | |
| "id" : "3afed76d80884a44", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:01<00:00, 1.99s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:01<00:00, 1.98s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:02<00:00, 2.07s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. 
(Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:02<00:00, 2.44s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:03<00:00, 3.20s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:02<00:00, 2.85s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. 
(Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:01<00:00, 1.89s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:05<00:00, 5.38s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:03<00:00, 3.58s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. 
(Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:01<00:00, 1.17s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:02<00:00, 2.86s/it]\u001B[A\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_5310/493151998.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = perturb_for_target_label(wdc_cameras_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "\n", " 0%| | 0/1 [00:00<?, ?it/s]\u001B[A\n", "100%|██████████| 1/1 [00:01<00:00, 1.66s/it]\u001B[A\n" ] | |
| } ], | |
| "execution_count" : 184 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:25:54.687618Z", | |
| "start_time" : "2025-12-04T15:25:54.674936Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# Per explainer: sum of its k_range new predictions divided by k_range\n", "# (the share of runs predicted 1 when predictions are 0/1).\n", "consistency.groupby('explainer')[['prediction']].sum().div(k_range)" ], | |
| "id" : "98f451ed7e6bee27", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " prediction\n", "explainer \n", "post_hoc 0.833333\n", "self 0.500000" ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>prediction</th>\n", " </tr>\n", " <tr>\n", " <th>explainer</th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>post_hoc</th>\n", " <td>0.833333</td>\n", " </tr>\n", " <tr>\n", " <th>self</th>\n", " <td>0.500000</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 189, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 189 | |
| } ], | |
| "metadata" : { | |
| "kernelspec" : { | |
| "display_name" : "Python 3", | |
| "language" : "python", | |
| "name" : "python3" | |
| }, | |
| "language_info" : { | |
| "codemirror_mode" : { | |
| "name" : "ipython", | |
| "version" : 2 | |
| }, | |
| "file_extension" : ".py", | |
| "mimetype" : "text/x-python", | |
| "name" : "python", | |
| "nbconvert_exporter" : "python", | |
| "pygments_lexer" : "ipython2", | |
| "version" : "2.7.6" | |
| } | |
| }, | |
| "nbformat" : 4, | |
| "nbformat_minor" : 5 | |
| } |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment