Created
December 4, 2025 15:56
-
-
Save tteofili/ec3e2d1dddf0b1ccc0507336458427bc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells" : [ { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:39:41.213606Z", | |
| "start_time" : "2025-12-04T15:39:41.212183Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : "", | |
| "id" : "e635173ef70d8762", | |
| "outputs" : [ ], | |
| "execution_count" : null | |
| }, { | |
| "metadata" : { | |
| "collapsed" : true, | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:27:55.135540Z", | |
| "start_time" : "2025-12-04T15:27:52.008915Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "import pandas as pd\n", "import json\n", "import operator\n", "import re\n", "\n", "from ellmer.selfexplainer import SelfExplainer" ], | |
| "id" : "909306fc9640edd9", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "/Users/tteofili/dev/ellmer/ellmer/selfexplainer.py:535: SyntaxWarning: \"is not\" with a literal. Did you mean \"!=\"?\n", " if nm is not None and len(ns) is not 0 and len(cf) is not 0:\n", "/Users/tteofili/dev/ellmer/ellmer/selfexplainer.py:535: SyntaxWarning: \"is not\" with a literal. Did you mean \"!=\"?\n", " if nm is not None and len(ns) is not 0 and len(cf) is not 0:\n" ] | |
| } ], | |
| "execution_count" : 3 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:27:55.142210Z", | |
| "start_time" : "2025-12-04T15:27:55.138584Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "def token_perturb_for_target_label(row, ranked_cols, predict_fn, label, k):\n", "\n", " row = row.copy()\n", " top_k = ranked_cols[:k]\n", "\n", " for entry in top_k:\n", " if \"__\" not in entry:\n", " raise ValueError(f\"Ranked column entry must be 'attr__token', got: {entry}\")\n", "\n", " attr, token = entry.split(\"__\", 1)\n", "\n", " if attr not in row:\n", " continue\n", "\n", " original_val = str(row[attr])\n", "\n", " counterpart_attr = None\n", " if attr.startswith(\"ltable_\"):\n", " counterpart_attr = \"rtable_\" + attr[len(\"ltable_\"):]\n", " elif attr.startswith(\"rtable_\"):\n", " counterpart_attr = \"ltable_\" + attr[len(\"rtable_\"):]\n", "\n", " if label == 1:\n", " # remove the token from the attribute\n", " pattern = re.escape(token)\n", " new_val = re.sub(pattern, \"\", original_val, flags=re.IGNORECASE)\n", "\n", " # clean double spaces left behind\n", " new_val = re.sub(r\"\\s+\", \" \", new_val).strip()\n", " row[attr] = new_val\n", "\n", " else:\n", " # make token equal to the counterpart’s value or its matching token\n", " if counterpart_attr in row:\n", " counterpart_val = str(row[counterpart_attr])\n", "\n", " # if counterpart contains the same token, force-copy only that token\n", " if re.search(re.escape(token), counterpart_val, flags=re.IGNORECASE):\n", " replacement = token\n", " else:\n", " # replace token with the *whole counterpart attribute* (fallback)\n", " replacement = counterpart_val\n", "\n", " pattern = re.escape(token)\n", " new_val = re.sub(pattern, replacement, original_val, flags=re.IGNORECASE)\n", "\n", " row[attr] = new_val\n", "\n", " new_pred = predict_fn(row)\n", " return row, new_pred\n" ], | |
| "id" : "428d878712482b1", | |
| "outputs" : [ ], | |
| "execution_count" : 4 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:27:57.383217Z", | |
| "start_time" : "2025-12-04T15:27:57.224350Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "llm = SelfExplainer(explanation_granularity='attribute',\n", " temperature=1,\n", " model_name='gpt-5-nano', model_type='azure_openai',\n", " prompts={\"ptse\": {\"er\": \"../ellmer/prompts/er.txt\"}})" ], | |
| "id" : "b56f71faca9a505a", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "/Users/tteofili/dev/ellmer/ellmer/selfexplainer.py:38: LangChainDeprecationWarning: The class `AzureChatOpenAI` was deprecated in LangChain 0.0.10 and will be removed in 1.0. An updated version of the class exists in the :class:`~langchain-openai package and should be used instead. To use it run `pip install -U :class:`~langchain-openai` and import as `from :class:`~langchain_openai import AzureChatOpenAI``.\n", " self.llm = AzureChatOpenAI(model_name=model_name, request_timeout=120,\n" ] | |
| } ], | |
| "execution_count" : 5 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:28:00.185812Z", | |
| "start_time" : "2025-12-04T15:28:00.181468Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "def predict_fn(row):\n", " return llm.predict(pd.DataFrame(row))['match_score'].values[0]\n" ], | |
| "id" : "a710f7acb402396a", | |
| "outputs" : [ ], | |
| "execution_count" : 6 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:22.982092Z", | |
| "start_time" : "2025-12-04T15:37:22.976149Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "with open(\"/Users/tteofili/dev/ellmer/experiments/azure_openai/gpt-5-nano/token/beers/20251204/14_51/0_results.json\") as f:\n", " d = json.load(f)" ], | |
| "id" : "20da3c2e2a410071", | |
| "outputs" : [ ], | |
| "execution_count" : 46 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:24.204944Z", | |
| "start_time" : "2025-12-04T15:37:24.200170Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "ltuple = pd.json_normalize(d['data'][0]['cot_sample']['ltuple']).add_prefix('ltable_')\n", "rtuple = pd.json_normalize(d['data'][0]['cot_sample']['rtuple']).add_prefix('rtable_')" ], | |
| "id" : "e6b98d295d5adc27", | |
| "outputs" : [ ], | |
| "execution_count" : 47 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:25.377641Z", | |
| "start_time" : "2025-12-04T15:37:25.369509Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "beers_row = pd.concat([ltuple, rtuple], axis=1)\n", "beers_row" ], | |
| "id" : "3bba24732497f80f", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_Beer_Name ltable_Brew_Factory_Name \\\n", "0 Bulleit Bourbon Barrel Aged G'Knight Oskar Blues Grill & Brew \n", "\n", " ltable_Style ltable_ABV \\\n", "0 American Amber / Red Ale 8.70 % \n", "\n", " rtable_Beer_Name rtable_Brew_Factory_Name \\\n", "0 Figure Eight Bourbon Barrel Aged Jumbo Love Figure Eight Brewing \n", "\n", " rtable_Style rtable_ABV \n", "0 Barley Wine - " ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>ltable_Beer_Name</th>\n", " <th>ltable_Brew_Factory_Name</th>\n", " <th>ltable_Style</th>\n", " <th>ltable_ABV</th>\n", " <th>rtable_Beer_Name</th>\n", " <th>rtable_Brew_Factory_Name</th>\n", " <th>rtable_Style</th>\n", " <th>rtable_ABV</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Bulleit Bourbon Barrel Aged G'Knight</td>\n", " <td>Oskar Blues Grill & Brew</td>\n", " <td>American Amber / Red Ale</td>\n", " <td>8.70 %</td>\n", " <td>Figure Eight Bourbon Barrel Aged Jumbo Love</td>\n", " <td>Figure Eight Brewing</td>\n", " <td>Barley Wine</td>\n", " <td>-</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 48, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 48 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:26.279169Z", | |
| "start_time" : "2025-12-04T15:37:26.271760Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "prediction = pd.json_normalize(d['data'][0]['cot_sample'])['prediction']\n", "prediction" ], | |
| "id" : "2565ccf869a697d5", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ "0 0\n", "Name: prediction, dtype: int64" ] | |
| }, | |
| "execution_count" : 49, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 49 | |
| }, { | |
| "metadata" : { }, | |
| "cell_type" : "code", | |
| "outputs" : [ ], | |
| "execution_count" : null, | |
| "source" : "", | |
| "id" : "9fe431c5ad2e4300" | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:27.858339Z", | |
| "start_time" : "2025-12-04T15:37:27.845997Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "self_explanation = pd.json_normalize(d['data'][0]['cot_sample']['saliency'])\n", "self_explanation" ], | |
| "id" : "bcd247cd99207821", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_Beer_Name__Bulleit ltable_Beer_Name__Bourbon \\\n", "0 0.7 0.25 \n", "\n", " ltable_Beer_Name__Barrel ltable_Beer_Name__Aged \\\n", "0 0.2 0.15 \n", "\n", " ltable_Beer_Name__G'Knight ltable_Brew_Factory_Name__Oskar \\\n", "0 0.65 0.05 \n", "\n", " ltable_Brew_Factory_Name__Blues ltable_Brew_Factory_Name__Grill \\\n", "0 0.05 0.05 \n", "\n", " ltable_Brew_Factory_Name__& ltable_Brew_Factory_Name__Brew ... \\\n", "0 0.05 0.05 ... \n", "\n", " rtable_Beer_Name__Barrel rtable_Beer_Name__Aged rtable_Beer_Name__Jumbo \\\n", "0 0.2 0.18 0.76 \n", "\n", " rtable_Beer_Name__Love rtable_Brew_Factory_Name__Figure \\\n", "0 0.64 0.5 \n", "\n", " rtable_Brew_Factory_Name__Eight rtable_Brew_Factory_Name__Brewing \\\n", "0 0.5 0.6 \n", "\n", " rtable_Style__Barley rtable_Style__Wine rtable_ABV__- \n", "0 0.6 0.55 0.1 \n", "\n", "[1 rows x 29 columns]" ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>ltable_Beer_Name__Bulleit</th>\n", " <th>ltable_Beer_Name__Bourbon</th>\n", " <th>ltable_Beer_Name__Barrel</th>\n", " <th>ltable_Beer_Name__Aged</th>\n", " <th>ltable_Beer_Name__G'Knight</th>\n", " <th>ltable_Brew_Factory_Name__Oskar</th>\n", " <th>ltable_Brew_Factory_Name__Blues</th>\n", " <th>ltable_Brew_Factory_Name__Grill</th>\n", " <th>ltable_Brew_Factory_Name__&</th>\n", " <th>ltable_Brew_Factory_Name__Brew</th>\n", " <th>...</th>\n", " <th>rtable_Beer_Name__Barrel</th>\n", " <th>rtable_Beer_Name__Aged</th>\n", " <th>rtable_Beer_Name__Jumbo</th>\n", " <th>rtable_Beer_Name__Love</th>\n", " <th>rtable_Brew_Factory_Name__Figure</th>\n", " <th>rtable_Brew_Factory_Name__Eight</th>\n", " <th>rtable_Brew_Factory_Name__Brewing</th>\n", " <th>rtable_Style__Barley</th>\n", " <th>rtable_Style__Wine</th>\n", " <th>rtable_ABV__-</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.7</td>\n", " <td>0.25</td>\n", " <td>0.2</td>\n", " <td>0.15</td>\n", " <td>0.65</td>\n", " <td>0.05</td>\n", " <td>0.05</td>\n", " <td>0.05</td>\n", " <td>0.05</td>\n", " <td>0.05</td>\n", " <td>...</td>\n", " <td>0.2</td>\n", " <td>0.18</td>\n", " <td>0.76</td>\n", " <td>0.64</td>\n", " <td>0.5</td>\n", " <td>0.5</td>\n", " <td>0.6</td>\n", " <td>0.6</td>\n", " <td>0.55</td>\n", " <td>0.1</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>1 rows × 29 columns</p>\n", "</div>" ] | |
| }, | |
| "execution_count" : 50, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 50 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:29.066595Z", | |
| "start_time" : "2025-12-04T15:37:29.056915Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "post_hoc_explanation = pd.json_normalize(d['data'][0]['certa_sample']['saliency'])\n", "post_hoc_explanation" ], | |
| "id" : "ea0e7d561021f797", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_Brew_Factory_Name__Blues ltable_Beer_Name__G'Knight \\\n", "0 0.590909 0.636364 \n", "\n", " ltable_ABV__8.70 ltable_Beer_Name__Bulleit ltable_Brew_Factory_Name__& \\\n", "0 0.636364 0.681818 0.681818 \n", "\n", " ltable_Brew_Factory_Name__Oskar ltable_Style__/ ltable_Style__Amber \\\n", "0 0.681818 0.681818 0.681818 \n", "\n", " ltable_Beer_Name__Aged ltable_Beer_Name__Barrel ... \\\n", "0 0.545455 0.590909 ... \n", "\n", " rtable_Beer_Name__Eight rtable_Beer_Name__Figure rtable_Style__Barley \\\n", "0 0.318182 0.318182 0.318182 \n", "\n", " rtable_Beer_Name__Love rtable_Brew_Factory_Name__Figure \\\n", "0 0.272727 0.318182 \n", "\n", " rtable_Beer_Name__Bourbon rtable_Beer_Name__Barrel \\\n", "0 0.318182 0.318182 \n", "\n", " rtable_Brew_Factory_Name__Brewing ltable_Style__American ltable_ABV__% \n", "0 0.318182 0.590909 0.590909 \n", "\n", "[1 rows x 29 columns]" ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>ltable_Brew_Factory_Name__Blues</th>\n", " <th>ltable_Beer_Name__G'Knight</th>\n", " <th>ltable_ABV__8.70</th>\n", " <th>ltable_Beer_Name__Bulleit</th>\n", " <th>ltable_Brew_Factory_Name__&</th>\n", " <th>ltable_Brew_Factory_Name__Oskar</th>\n", " <th>ltable_Style__/</th>\n", " <th>ltable_Style__Amber</th>\n", " <th>ltable_Beer_Name__Aged</th>\n", " <th>ltable_Beer_Name__Barrel</th>\n", " <th>...</th>\n", " <th>rtable_Beer_Name__Eight</th>\n", " <th>rtable_Beer_Name__Figure</th>\n", " <th>rtable_Style__Barley</th>\n", " <th>rtable_Beer_Name__Love</th>\n", " <th>rtable_Brew_Factory_Name__Figure</th>\n", " <th>rtable_Beer_Name__Bourbon</th>\n", " <th>rtable_Beer_Name__Barrel</th>\n", " <th>rtable_Brew_Factory_Name__Brewing</th>\n", " <th>ltable_Style__American</th>\n", " <th>ltable_ABV__%</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.590909</td>\n", " <td>0.636364</td>\n", " <td>0.636364</td>\n", " <td>0.681818</td>\n", " <td>0.681818</td>\n", " <td>0.681818</td>\n", " <td>0.681818</td>\n", " <td>0.681818</td>\n", " <td>0.545455</td>\n", " <td>0.590909</td>\n", " <td>...</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.272727</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.590909</td>\n", " <td>0.590909</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>1 rows × 29 columns</p>\n", "</div>" ] | |
| }, | |
| "execution_count" : 51, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 51 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:37:32.401691Z", | |
| "start_time" : "2025-12-04T15:37:32.390848Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "# compare\n", "explanations = pd.concat([self_explanation, post_hoc_explanation], axis=0, keys=['self', 'post_hoc'])\n", "explanations.head()" ], | |
| "id" : "92275ff5b2d6fc4d", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " ltable_Beer_Name__Bulleit ltable_Beer_Name__Bourbon \\\n", "self 0 0.700000 0.250000 \n", "post_hoc 0 0.681818 0.681818 \n", "\n", " ltable_Beer_Name__Barrel ltable_Beer_Name__Aged \\\n", "self 0 0.200000 0.150000 \n", "post_hoc 0 0.590909 0.545455 \n", "\n", " ltable_Beer_Name__G'Knight ltable_Brew_Factory_Name__Oskar \\\n", "self 0 0.650000 0.050000 \n", "post_hoc 0 0.636364 0.681818 \n", "\n", " ltable_Brew_Factory_Name__Blues ltable_Brew_Factory_Name__Grill \\\n", "self 0 0.050000 0.05 \n", "post_hoc 0 0.590909 0.50 \n", "\n", " ltable_Brew_Factory_Name__& ltable_Brew_Factory_Name__Brew ... \\\n", "self 0 0.050000 0.05 ... \n", "post_hoc 0 0.681818 NaN ... \n", "\n", " rtable_Beer_Name__Aged rtable_Beer_Name__Jumbo \\\n", "self 0 0.180000 0.760000 \n", "post_hoc 0 0.318182 0.272727 \n", "\n", " rtable_Beer_Name__Love rtable_Brew_Factory_Name__Figure \\\n", "self 0 0.640000 0.500000 \n", "post_hoc 0 0.272727 0.318182 \n", "\n", " rtable_Brew_Factory_Name__Eight \\\n", "self 0 0.500000 \n", "post_hoc 0 0.318182 \n", "\n", " rtable_Brew_Factory_Name__Brewing rtable_Style__Barley \\\n", "self 0 0.600000 0.600000 \n", "post_hoc 0 0.318182 0.318182 \n", "\n", " rtable_Style__Wine rtable_ABV__- ltable_Style__/ \n", "self 0 0.550000 0.100000 NaN \n", "post_hoc 0 0.318182 0.318182 0.681818 \n", "\n", "[2 rows x 30 columns]" ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th></th>\n", " <th>ltable_Beer_Name__Bulleit</th>\n", " <th>ltable_Beer_Name__Bourbon</th>\n", " <th>ltable_Beer_Name__Barrel</th>\n", " <th>ltable_Beer_Name__Aged</th>\n", " <th>ltable_Beer_Name__G'Knight</th>\n", " <th>ltable_Brew_Factory_Name__Oskar</th>\n", " <th>ltable_Brew_Factory_Name__Blues</th>\n", " <th>ltable_Brew_Factory_Name__Grill</th>\n", " <th>ltable_Brew_Factory_Name__&</th>\n", " <th>ltable_Brew_Factory_Name__Brew</th>\n", " <th>...</th>\n", " <th>rtable_Beer_Name__Aged</th>\n", " <th>rtable_Beer_Name__Jumbo</th>\n", " <th>rtable_Beer_Name__Love</th>\n", " <th>rtable_Brew_Factory_Name__Figure</th>\n", " <th>rtable_Brew_Factory_Name__Eight</th>\n", " <th>rtable_Brew_Factory_Name__Brewing</th>\n", " <th>rtable_Style__Barley</th>\n", " <th>rtable_Style__Wine</th>\n", " <th>rtable_ABV__-</th>\n", " <th>ltable_Style__/</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>self</th>\n", " <th>0</th>\n", " <td>0.700000</td>\n", " <td>0.250000</td>\n", " <td>0.200000</td>\n", " <td>0.150000</td>\n", " <td>0.650000</td>\n", " <td>0.050000</td>\n", " <td>0.050000</td>\n", " <td>0.05</td>\n", " <td>0.050000</td>\n", " <td>0.05</td>\n", " <td>...</td>\n", " <td>0.180000</td>\n", " <td>0.760000</td>\n", " <td>0.640000</td>\n", " <td>0.500000</td>\n", " <td>0.500000</td>\n", " <td>0.600000</td>\n", " <td>0.600000</td>\n", " <td>0.550000</td>\n", " <td>0.100000</td>\n", " <td>NaN</td>\n", " </tr>\n", " <tr>\n", " <th>post_hoc</th>\n", " <th>0</th>\n", " <td>0.681818</td>\n", " <td>0.681818</td>\n", " <td>0.590909</td>\n", " <td>0.545455</td>\n", " <td>0.636364</td>\n", " <td>0.681818</td>\n", " <td>0.590909</td>\n", " <td>0.50</td>\n", " <td>0.681818</td>\n", " <td>NaN</td>\n", " <td>...</td>\n", " <td>0.318182</td>\n", " <td>0.272727</td>\n", " <td>0.272727</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.318182</td>\n", " <td>0.681818</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>2 rows × 30 columns</p>\n", "</div>" ] | |
| }, | |
| "execution_count" : 52, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 52 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:55:30.655705Z", | |
| "start_time" : "2025-12-04T15:55:30.503424Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "explanations_viz = explanations.copy()\n", "explanations_viz.columns = explanations.columns.str.split('__').str[-1]\n", "explanations_viz.T.plot(kind='bar', stacked=False, figsize=(12, 6), rot=58)" ], | |
| "id" : "fcf69878d4338a74", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ "<Axes: >" ] | |
| }, | |
| "execution_count" : 60, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| }, { | |
| "data" : { | |
| "text/plain" : [ "<Figure size 1200x600 with 1 Axes>" ], | |
| "image/png" : "" | |
| }, | |
| "metadata" : { }, | |
| "output_type" : "display_data", | |
| "jetTransient" : { | |
| "display_id" : null | |
| } | |
| } ], | |
| "execution_count" : 60 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:38:03.892640Z", | |
| "start_time" : "2025-12-04T15:38:01.520122Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "expl1 = [r[0] for r in sorted(explanations.iloc[0].to_dict().items(), key=operator.itemgetter(1), reverse=True)]\n", "token_perturb_for_target_label(beers_row, expl1, predict_fn, label=int(prediction.values), k=5)" ], | |
| "id" : "2af6f53ef7514183", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/2637309926.py:2: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " token_perturb_for_target_label(beers_row, expl1, predict_fn, label=int(prediction.values), k=5)\n", " 0%| | 0/1 [00:00<?, ?it/s]/Users/tteofili/dev/ellmer/ellmer/selfexplainer.py:80: LangChainDeprecationWarning: The method `BaseChatModel.__call__` was deprecated in langchain-core 0.1.7 and will be removed in 1.0. Use :meth:`~invoke` instead.\n", " answer = self.llm(messages)\n", "100%|██████████| 1/1 [00:02<00:00, 2.36s/it]\n" ] | |
| }, { | |
| "data" : { | |
| "text/plain" : [ "( ltable_Beer_Name \\\n", " 0 0 0 Bulleit Bourbon Barrel Aged 0 0 ... \n", " \n", " ltable_Brew_Factory_Name ltable_Style ltable_ABV \\\n", " 0 Oskar Blues Grill & Brew American Amber / Red Ale 8.70 % \n", " \n", " rtable_Beer_Name rtable_Brew_Factory_Name \\\n", " 0 0 0 0 0 Bulleit Bourbon Barrel Age... Figure Eight Brewing \n", " \n", " rtable_Style rtable_ABV \n", " 0 Barley Wine - ,\n", " 0)" ] | |
| }, | |
| "execution_count" : 55, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 55 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:38:58.474991Z", | |
| "start_time" : "2025-12-04T15:38:16.737600Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : [ "k_range = 6\n", "res = []\n", "\n", "for i in range(len(explanations)):\n", " expl = explanations.iloc[i]\n", " explainer = expl.name[0]\n", " salient_features = [r[0] for r in sorted(expl.to_dict().items(), key=operator.itemgetter(1), reverse=True)]\n", " for k in range(k_range):\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", " cr = {'prediction': pred, 'perturbed': pert, 'explainer': explainer}\n", " res.append(cr)\n", "\n", "consistency = pd.DataFrame(res)\n" ], | |
| "id" : "3afed76d80884a44", | |
| "outputs" : [ { | |
| "name" : "stderr", | |
| "output_type" : "stream", | |
| "text" : [ "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:02<00:00, 2.11s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:04<00:00, 4.02s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:02<00:00, 2.20s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:02<00:00, 2.67s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:03<00:00, 3.72s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:02<00:00, 2.18s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:01<00:00, 1.35s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:02<00:00, 2.94s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:04<00:00, 4.62s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:07<00:00, 7.05s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:05<00:00, 5.73s/it]\n", "/var/folders/mr/6xnd0hrs6257283btx8ff88r0000gn/T/ipykernel_9919/1240791563.py:7: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", " pert, pred = token_perturb_for_target_label(beers_row, salient_features, predict_fn, label=int(prediction.values), k=k)\n", "100%|██████████| 1/1 [00:03<00:00, 3.10s/it]\n" ] | |
| } ], | |
| "execution_count" : 56 | |
| }, { | |
| "metadata" : { | |
| "ExecuteTime" : { | |
| "end_time" : "2025-12-04T15:38:58.513983Z", | |
| "start_time" : "2025-12-04T15:38:58.503571Z" | |
| } | |
| }, | |
| "cell_type" : "code", | |
| "source" : "consistency[['explainer','prediction']].groupby(['explainer']).sum()/k_range", | |
| "id" : "98f451ed7e6bee27", | |
| "outputs" : [ { | |
| "data" : { | |
| "text/plain" : [ " prediction\n", "explainer \n", "post_hoc 0.333333\n", "self 0.000000" ], | |
| "text/html" : [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>prediction</th>\n", " </tr>\n", " <tr>\n", " <th>explainer</th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>post_hoc</th>\n", " <td>0.333333</td>\n", " </tr>\n", " <tr>\n", " <th>self</th>\n", " <td>0.000000</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ] | |
| }, | |
| "execution_count" : 57, | |
| "metadata" : { }, | |
| "output_type" : "execute_result" | |
| } ], | |
| "execution_count" : 57 | |
| } ], | |
| "metadata" : { | |
| "kernelspec" : { | |
| "display_name" : "Python 3", | |
| "language" : "python", | |
| "name" : "python3" | |
| }, | |
| "language_info" : { | |
| "codemirror_mode" : { | |
| "name" : "ipython", | |
| "version" : 2 | |
| }, | |
| "file_extension" : ".py", | |
| "mimetype" : "text/x-python", | |
| "name" : "python", | |
| "nbconvert_exporter" : "python", | |
| "pygments_lexer" : "ipython2", | |
| "version" : "2.7.6" | |
| } | |
| }, | |
| "nbformat" : 4, | |
| "nbformat_minor" : 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment