Created
November 25, 2025 19:16
-
-
Save ethanabrooks/29895679ecb6c045b732a971f5d5299c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "d6844df1", | |
| "metadata": {}, | |
| "source": [ | |
| "# Dataset Analysis Report\n", | |
| "\n", | |
| "This notebook analyzes pass@1 evaluation results across multiple samples per problem.\n", | |
| "The notebook is generated using jupytext, and the `config` in the cell tagged `minos-config` (the cell after the imports) is auto-populated." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "3a88dd6b", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:14:51.107486Z", | |
| "iopub.status.busy": "2025-11-25T19:14:51.107338Z", | |
| "iopub.status.idle": "2025-11-25T19:15:04.905939Z", | |
| "shell.execute_reply": "2025-11-25T19:15:04.905319Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/ethan/.cache/pants/named_caches/pex_root/venvs/3/bf7d981e9fd29e8ce739c11346899ca46e3c84b4/4978de77e09000b88dd0d51c4ca60c416127ff6b/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "2025-11-25 19:14:55,267\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "WARNING:absl:Failed to import TraceAnnotation.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/ethan/.cache/pants/named_caches/pex_root/venvs/3/bf7d981e9fd29e8ce739c11346899ca46e3c84b4/4978de77e09000b88dd0d51c4ca60c416127ff6b/lib/python3.12/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'default' attribute with value 'name' was provided to the `Field()` function, which has no effect in the context it was used. 'default' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", | |
| " warnings.warn(\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "\n", | |
| "import os\n", | |
| "\n", | |
| "import matplotlib\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "import seaborn as sns\n", | |
| "import wandb\n", | |
| "from IPython.display import (\n", | |
| " display, # pyright: ignore[reportUnknownVariableType]\n", | |
| ")\n", | |
| "\n", | |
| "from olympus.projects.minos.scripts.best_of_n_analysis import (\n", | |
| " automation,\n", | |
| " dummy_config,\n", | |
| " metrics,\n", | |
| " models,\n", | |
| ")\n", | |
| "from olympus.storage.t2 import spanner as spanner_trace_storage\n", | |
| "\n", | |
| "matplotlib.use('Agg')\n", | |
| "\n", | |
| "# NOTE:\n", | |
| "# - The *injection point* is the next cell (tagged `minos-config`).\n", | |
| "# - A generator script will overwrite the body of that cell with a concrete\n", | |
| "# `models.ExperimentConfig(...)` assignment.\n", | |
| "# - The `dummy_config` import above ensures type checking works before injection." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "18538e95", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:04.908315Z", | |
| "iopub.status.busy": "2025-11-25T19:15:04.907689Z", | |
| "iopub.status.idle": "2025-11-25T19:15:04.911133Z", | |
| "shell.execute_reply": "2025-11-25T19:15:04.910559Z" | |
| }, | |
| "tags": [ | |
| "minos-config" | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from olympus.projects.minos.scripts.best_of_n_analysis import models\n", | |
| "\n", | |
| "config = models.ExperimentConfig(\n", | |
| " collection='/minos/bon_single/guru-math__deepscaler_preview_11_0.8_1000-test-with-tools/qwen/51653a9eedfb',\n", | |
| " group_size=3,\n", | |
| " rl_job_url='https://graphein.reflectionai.dev/?page=1&pageSize=25&nameFilter=curious-port&experiment_name=ethan-curious-port-11-07',\n", | |
| " rl_train_collection='/swe-bench/online-rl/ethan-curious-port-11-07-train',\n", | |
| " rl_test_collection='/swe-bench/online-rl/ethan-curious-port-11-07-test',\n", | |
| " wandb_run_url='https://wandb.ai/reflectionai/tools/runs/01bdg4ux',\n", | |
| ")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "c41a6faf", | |
| "metadata": {}, | |
| "source": [ | |
| "## Configuration\n", | |
| "\n", | |
| "Configuration is injected in the cell above." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "ca2c04aa", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:04.912437Z", | |
| "iopub.status.busy": "2025-11-25T19:15:04.912214Z", | |
| "iopub.status.idle": "2025-11-25T19:15:04.934552Z", | |
| "shell.execute_reply": "2025-11-25T19:15:04.934081Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loaded stats with 3002 rows\n", | |
| "Columns: ['trace/prompt_len', 'trace/response_len', 'trace/num_messages', 'trace/prompt', 'trace/sample', 'trace/index', 'sequence', 'sequence/score', 'sequence/instance_id']\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Load statistics DataFrame from collection\n", | |
| "# This reads sequences from the collection and computes trace statistics\n", | |
| "read_sequences = metrics.read_sequences(\n", | |
| " config.collection,\n", | |
| " 'qwen3', # formatter name\n", | |
| " chunk_size=100,\n", | |
| " limit=None, # Set to a number to limit sequences for debugging\n", | |
| " use_cache=True, # Uses file-based caching for faster subsequent runs\n", | |
| ")\n", | |
| "stats = await read_sequences\n", | |
| "\n", | |
| "print(f'Loaded stats with {len(stats)} rows')\n", | |
| "print(f'Columns: {list(stats.columns)}')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "eb3e68b5", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:04.935979Z", | |
| "iopub.status.busy": "2025-11-25T19:15:04.935827Z", | |
| "iopub.status.idle": "2025-11-25T19:15:05.864959Z", | |
| "shell.execute_reply": "2025-11-25T19:15:05.864273Z" | |
| }, | |
| "tags": [ | |
| "rl-scores" | |
| ] | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:projects/reflectionai/instances/dataprism/databases/production:Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "RL train collection: /swe-bench/online-rl/ethan-curious-port-11-07-train\n", | |
| "RL test collection: /swe-bench/online-rl/ethan-curious-port-11-07-test\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:projects/reflectionai/instances/dataprism/databases/production:Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Retrieved 51 checkpoint scores from /swe-bench/online-rl/ethan-curious-port-11-07-train\n", | |
| "First score: checkpoint_step=0, average_score=0.140\n", | |
| "Last score: checkpoint_step=1250, average_score=0.774\n", | |
| "RL start accuracy: 0.14\n", | |
| "RL end accuracy: 0.77\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "RL completion length: 4875.7 tokens\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Get RL training scores from trace storage and compute accuracy metrics\n", | |
| "# Extract experiment name from RL job URL\n", | |
| "experiment_name = automation.extract_experiment_name_from_url(config.rl_job_url)\n", | |
| "\n", | |
| "# Get launch config for the experiment\n", | |
| "launch_config_obj = automation.get_launch_config(experiment_name)\n", | |
| "\n", | |
| "# Train and test collection paths are read from `config` (not from `launch_config_obj`)\n", | |
| "# Assert they are set since the steps below require them\n", | |
| "assert config.rl_train_collection is not None, 'rl_train_collection is required'\n", | |
| "assert config.rl_test_collection is not None, 'rl_test_collection is required'\n", | |
| "\n", | |
| "print(f'RL train collection: {config.rl_train_collection}')\n", | |
| "print(f'RL test collection: {config.rl_test_collection}')\n", | |
| "\n", | |
| "# Get scores from trace storage (using train collection)\n", | |
| "trace_storage = spanner_trace_storage.create_production_spanner_trace_storage()\n", | |
| "scores = trace_storage.get_average_score_per_checkpoint_step(\n", | |
| " config.rl_train_collection\n", | |
| ")\n", | |
| "\n", | |
| "print(\n", | |
| " f'Retrieved {len(scores)} checkpoint scores from {config.rl_train_collection}'\n", | |
| ")\n", | |
| "if scores:\n", | |
| " start_score = scores[0]\n", | |
| " end_score = scores[-1]\n", | |
| " rl_start_accuracy = f'{start_score.average_score:.2f}'\n", | |
| " rl_end_accuracy = f'{end_score.average_score:.2f}'\n", | |
| "\n", | |
| " print(\n", | |
| " f'First score: checkpoint_step={start_score.checkpoint_step}, average_score={start_score.average_score:.3f}'\n", | |
| " )\n", | |
| " print(\n", | |
| " f'Last score: checkpoint_step={end_score.checkpoint_step}, average_score={end_score.average_score:.3f}'\n", | |
| " )\n", | |
| " print(f'RL start accuracy: {rl_start_accuracy}')\n", | |
| " print(f'RL end accuracy: {rl_end_accuracy}')\n", | |
| "else:\n", | |
| " print('Warning: No scores found')\n", | |
| " rl_start_accuracy = None\n", | |
| " rl_end_accuracy = None\n", | |
| "\n", | |
| "# Extract completion length from WandB if URL provided\n", | |
| "rl_completion_length: float | None = None\n", | |
| "if config.wandb_run_url:\n", | |
| " experiment_name_wandb, run_id = automation.parse_wandb_run_url(\n", | |
| " config.wandb_run_url\n", | |
| " )\n", | |
| "\n", | |
| " api_key = os.environ.get('WANDB_API_KEY')\n", | |
| " if api_key:\n", | |
| " wandb.login(key=api_key)\n", | |
| "\n", | |
| " api = wandb.Api(timeout=60)\n", | |
| " run = api.run(f'reflectionai/{experiment_name_wandb}/{run_id}') # pyright: ignore[reportUnknownVariableType]\n", | |
| "\n", | |
| " metric_name = 'steps/rendered/completion_length/mean'\n", | |
| " summary = run.summary # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]\n", | |
| " if metric_name in summary:\n", | |
| " value = summary[metric_name] # pyright: ignore[reportAssignmentType]\n", | |
| " if isinstance(value, (int, float)):\n", | |
| " rl_completion_length = float(value)\n", | |
| " print(f'RL completion length: {rl_completion_length:.1f} tokens')\n", | |
| " else:\n", | |
| " print(\n", | |
| " f'Warning: Metric {metric_name} found but value is not numeric: {value}'\n", | |
| " )\n", | |
| " else:\n", | |
| " print(f'Warning: Metric {metric_name} not found in WandB run summary')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "6bb3d4bc", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:05.866322Z", | |
| "iopub.status.busy": "2025-11-25T19:15:05.866142Z", | |
| "iopub.status.idle": "2025-11-25T19:15:05.881279Z", | |
| "shell.execute_reply": "2025-11-25T19:15:05.880780Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>trace/prompt_len</th>\n", | |
| " <th>trace/response_len</th>\n", | |
| " <th>trace/num_messages</th>\n", | |
| " <th>trace/prompt</th>\n", | |
| " <th>trace/sample</th>\n", | |
| " <th>trace/index</th>\n", | |
| " <th>sequence</th>\n", | |
| " <th>sequence/score</th>\n", | |
| " <th>sequence/instance_id</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>52</td>\n", | |
| " <td>493</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>9e02cd8b-2ed9-46cf-abb8-3ea37ea4205a</td>\n", | |
| " <td>1.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-10847</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>50</td>\n", | |
| " <td>290</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>43a06035-8822-4c04-ac4f-2267f5fa13c8</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-10231</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>72</td>\n", | |
| " <td>787</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>41cb0867-bea6-48d4-b942-7100f3b0e2a5</td>\n", | |
| " <td>1.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-126</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>61</td>\n", | |
| " <td>501</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>9401fbb9-2b38-4e26-8c60-89144c8612c8</td>\n", | |
| " <td>1.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-10261</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>67</td>\n", | |
| " <td>685</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>2e0bf4df-662c-41c9-8038-8ad897c7ca89</td>\n", | |
| " <td>1.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-10398</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>...</th>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " <td>...</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2997</th>\n", | |
| " <td>117</td>\n", | |
| " <td>1212</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>2</td>\n", | |
| " <td>dc6131c9-da0d-4707-a99a-21d9f9d7a48a</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-4758</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2998</th>\n", | |
| " <td>114</td>\n", | |
| " <td>9981</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>2</td>\n", | |
| " <td>23567486-fb1e-4570-b3b8-c8539db53713</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-9999</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2999</th>\n", | |
| " <td>151</td>\n", | |
| " <td>1271</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>2</td>\n", | |
| " <td>cf5eab86-8260-4b35-ae66-047e2701b300</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-5489</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3000</th>\n", | |
| " <td>94</td>\n", | |
| " <td>1521</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>1</td>\n", | |
| " <td>e8300fee-08a9-40dc-b7ba-01c366840e56</td>\n", | |
| " <td>1.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-53022</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3001</th>\n", | |
| " <td>94</td>\n", | |
| " <td>22537</td>\n", | |
| " <td>3</td>\n", | |
| " <td><|im_start|>system\\n<|im_end|>\\n<|im_start|>us...</td>\n", | |
| " <td><|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n...</td>\n", | |
| " <td>2</td>\n", | |
| " <td>1714230c-7a30-487b-abc9-3d0da2ebb9bf</td>\n", | |
| " <td>0.0</td>\n", | |
| " <td>guru-math__deepscaler_preview-53022</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>3002 rows × 9 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " trace/prompt_len trace/response_len trace/num_messages \\\n", | |
| "0 52 493 3 \n", | |
| "1 50 290 3 \n", | |
| "2 72 787 3 \n", | |
| "3 61 501 3 \n", | |
| "4 67 685 3 \n", | |
| "... ... ... ... \n", | |
| "2997 117 1212 3 \n", | |
| "2998 114 9981 3 \n", | |
| "2999 151 1271 3 \n", | |
| "3000 94 1521 3 \n", | |
| "3001 94 22537 3 \n", | |
| "\n", | |
| " trace/prompt \\\n", | |
| "0 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "1 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "2 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "3 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "4 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "... ... \n", | |
| "2997 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "2998 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "2999 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "3000 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "3001 <|im_start|>system\\n<|im_end|>\\n<|im_start|>us... \n", | |
| "\n", | |
| " trace/sample trace/index \\\n", | |
| "0 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 0 \n", | |
| "1 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 0 \n", | |
| "2 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 0 \n", | |
| "3 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 0 \n", | |
| "4 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 0 \n", | |
| "... ... ... \n", | |
| "2997 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 2 \n", | |
| "2998 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 2 \n", | |
| "2999 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 2 \n", | |
| "3000 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 1 \n", | |
| "3001 <|im_start|>assistant\\n<think>\\n\\n</think>\\n\\n... 2 \n", | |
| "\n", | |
| " sequence sequence/score \\\n", | |
| "0 9e02cd8b-2ed9-46cf-abb8-3ea37ea4205a 1.0 \n", | |
| "1 43a06035-8822-4c04-ac4f-2267f5fa13c8 0.0 \n", | |
| "2 41cb0867-bea6-48d4-b942-7100f3b0e2a5 1.0 \n", | |
| "3 9401fbb9-2b38-4e26-8c60-89144c8612c8 1.0 \n", | |
| "4 2e0bf4df-662c-41c9-8038-8ad897c7ca89 1.0 \n", | |
| "... ... ... \n", | |
| "2997 dc6131c9-da0d-4707-a99a-21d9f9d7a48a 0.0 \n", | |
| "2998 23567486-fb1e-4570-b3b8-c8539db53713 0.0 \n", | |
| "2999 cf5eab86-8260-4b35-ae66-047e2701b300 0.0 \n", | |
| "3000 e8300fee-08a9-40dc-b7ba-01c366840e56 1.0 \n", | |
| "3001 1714230c-7a30-487b-abc9-3d0da2ebb9bf 0.0 \n", | |
| "\n", | |
| " sequence/instance_id \n", | |
| "0 guru-math__deepscaler_preview-10847 \n", | |
| "1 guru-math__deepscaler_preview-10231 \n", | |
| "2 guru-math__deepscaler_preview-126 \n", | |
| "3 guru-math__deepscaler_preview-10261 \n", | |
| "4 guru-math__deepscaler_preview-10398 \n", | |
| "... ... \n", | |
| "2997 guru-math__deepscaler_preview-4758 \n", | |
| "2998 guru-math__deepscaler_preview-9999 \n", | |
| "2999 guru-math__deepscaler_preview-5489 \n", | |
| "3000 guru-math__deepscaler_preview-53022 \n", | |
| "3001 guru-math__deepscaler_preview-53022 \n", | |
| "\n", | |
| "[3002 rows x 9 columns]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Overall statistics\n", | |
| "num_instances = len(stats['sequence/instance_id'].unique()) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]\n", | |
| "num_traces = len(stats)\n", | |
| "score_mean = float(stats['sequence/score'].mean())\n", | |
| "score_std = (\n", | |
| " float(stats['sequence/score'].std(ddof=1)) if num_traces > 1 else 0.0\n", | |
| ")\n", | |
| "score_sem = score_std / np.sqrt(num_traces) if num_traces else float('nan')\n", | |
| "\n", | |
| "prompt_len = stats['trace/prompt_len']\n", | |
| "response_len = stats['trace/response_len']\n", | |
| "prompt_median = float(prompt_len.median())\n", | |
| "prompt_p95 = float(prompt_len.quantile(0.95)) # pyright: ignore[reportUnknownMemberType]\n", | |
| "response_median = float(response_len.median())\n", | |
| "response_p95 = float(response_len.quantile(0.95)) # pyright: ignore[reportUnknownMemberType]\n", | |
| "\n", | |
| "display(stats)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "d79f2d60", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:05.882457Z", | |
| "iopub.status.busy": "2025-11-25T19:15:05.882312Z", | |
| "iopub.status.idle": "2025-11-25T19:15:05.888918Z", | |
| "shell.execute_reply": "2025-11-25T19:15:05.888460Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>Outcome</th>\n", | |
| " <th>Count</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>Successes</td>\n", | |
| " <td>1559</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>Failures</td>\n", | |
| " <td>1443</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>Success Rate</td>\n", | |
| " <td>51.9%</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>Failure Rate</td>\n", | |
| " <td>48.1%</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " Outcome Count\n", | |
| "0 Successes 1559\n", | |
| "1 Failures 1443\n", | |
| "2 Success Rate 51.9%\n", | |
| "3 Failure Rate 48.1%" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Outcome breakdown\n", | |
| "total = len(stats)\n", | |
| "successes = stats[stats['sequence/score'] == 1.0]\n", | |
| "failures = stats[stats['sequence/score'] == 0.0]\n", | |
| "success_rate = len(successes) / total * 100 if total else 0.0\n", | |
| "failure_rate = len(failures) / total * 100 if total else 0.0\n", | |
| "\n", | |
| "num_successes = len(successes)\n", | |
| "num_failures = len(failures)\n", | |
| "\n", | |
| "outcomes_df = pd.DataFrame({\n", | |
| " 'Outcome': ['Successes', 'Failures', 'Success Rate', 'Failure Rate'],\n", | |
| " 'Count': [\n", | |
| " num_successes,\n", | |
| " num_failures,\n", | |
| " f'{success_rate:.1f}%',\n", | |
| " f'{failure_rate:.1f}%',\n", | |
| " ],\n", | |
| "})\n", | |
| "display(outcomes_df)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "9a443e0f", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:05.890058Z", | |
| "iopub.status.busy": "2025-11-25T19:15:05.889910Z", | |
| "iopub.status.idle": "2025-11-25T19:15:05.902531Z", | |
| "shell.execute_reply": "2025-11-25T19:15:05.902048Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>Rank</th>\n", | |
| " <th>Pass@N</th>\n", | |
| " <th>Pass Rate</th>\n", | |
| " <th>Incremental Lift</th>\n", | |
| " <th>Successes</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>54.0%</td>\n", | |
| " <td>54.0%</td>\n", | |
| " <td>540</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>2</td>\n", | |
| " <td>65.6%</td>\n", | |
| " <td>11.6%</td>\n", | |
| " <td>656</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>2</td>\n", | |
| " <td>3</td>\n", | |
| " <td>72.5%</td>\n", | |
| " <td>6.9%</td>\n", | |
| " <td>725</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " Rank Pass@N Pass Rate Incremental Lift Successes\n", | |
| "0 0 1 54.0% 54.0% 540\n", | |
| "1 1 2 65.6% 11.6% 656\n", | |
| "2 2 3 72.5% 6.9% 725" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
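| "# Rank pass rates: for each rank, take the best score per instance over samples with trace/index <= rank (i.e. Pass@(rank + 1))\n", | |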
| "rank_rates: list[models.RankPassRate] = []\n", | |
| "total_instances = stats['sequence/instance_id'].nunique()\n", | |
| "for rank in range(config.group_size):\n", | |
| " subset = stats[stats['trace/index'] <= rank]\n", | |
| " if subset.empty or total_instances == 0:\n", | |
| " pass_rate = float('nan')\n", | |
| " incremental_lift = float('nan')\n", | |
| " successes_count = 0\n", | |
| " else:\n", | |
| " best_by_instance = subset.groupby('sequence/instance_id')[ # pyright: ignore[reportUnknownMemberType]\n", | |
| " 'sequence/score'\n", | |
| " ].max()\n", | |
| " successes_count = int(best_by_instance.sum())\n", | |
| " pass_rate = successes_count / total_instances\n", | |
| " incremental_lift = pass_rate - (\n", | |
| " rank_rates[-1].pass_rate if rank_rates else 0.0\n", | |
| " )\n", | |
| "\n", | |
| " rank_rates.append(\n", | |
| " models.RankPassRate(\n", | |
| " rank=rank,\n", | |
| " pass_rate=pass_rate,\n", | |
| " incremental_lift=incremental_lift,\n", | |
| " successes=successes_count,\n", | |
| " )\n", | |
| " )\n", | |
| "\n", | |
| "rank_rates_df = pd.DataFrame({\n", | |
| " 'Rank': [rate.rank for rate in rank_rates],\n", | |
| " 'Pass@N': [rate.rank + 1 for rate in rank_rates],\n", | |
| " 'Pass Rate': [f'{rate.pass_rate:.1%}' for rate in rank_rates],\n", | |
| " 'Incremental Lift': [f'{rate.incremental_lift:.1%}' for rate in rank_rates],\n", | |
| " 'Successes': [rate.successes for rate in rank_rates],\n", | |
| "})\n", | |
| "display(rank_rates_df)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "37146cad", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:05.903640Z", | |
| "iopub.status.busy": "2025-11-25T19:15:05.903494Z", | |
| "iopub.status.idle": "2025-11-25T19:15:05.909889Z", | |
| "shell.execute_reply": "2025-11-25T19:15:05.909416Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>Successes</th>\n", | |
| " <th>Count</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0</td>\n", | |
| " <td>275</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>193</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>2</td>\n", | |
| " <td>231</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>3</td>\n", | |
| " <td>300</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " Successes Count\n", | |
| "0 0 275\n", | |
| "1 1 193\n", | |
| "2 2 231\n", | |
| "3 3 300" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Success buckets (distribution of successes per instance)\n", | |
| "successes_per_instance = (\n", | |
| " stats.groupby('sequence/instance_id')['sequence/score'] # pyright: ignore[reportUnknownMemberType]\n", | |
| " .sum()\n", | |
| " .astype(int)\n", | |
| ")\n", | |
| "success_buckets = [\n", | |
| " models.SuccessBucket(\n", | |
| " successes=successes,\n", | |
| " count=int((successes_per_instance == successes).sum()),\n", | |
| " )\n", | |
| " for successes in range(config.group_size + 1)\n", | |
| "]\n", | |
| "\n", | |
| "success_buckets_df = pd.DataFrame({\n", | |
| " 'Successes': [bucket.successes for bucket in success_buckets],\n", | |
| " 'Count': [bucket.count for bucket in success_buckets],\n", | |
| "})\n", | |
| "display(success_buckets_df)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "a16237f3", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:05.911007Z", | |
| "iopub.status.busy": "2025-11-25T19:15:05.910845Z", | |
| "iopub.status.idle": "2025-11-25T19:15:05.954189Z", | |
| "shell.execute_reply": "2025-11-25T19:15:05.953700Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 600x400 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Pass@N progression\n", | |
| "\n", | |
| "pass_points = [(rate.rank + 1, rate.pass_rate) for rate in rank_rates]\n", | |
| "if pass_points:\n", | |
| " fig, ax = plt.subplots(figsize=(6, 4)) # pyright: ignore[reportUnknownMemberType]\n", | |
| " ns = [n for n, _ in pass_points]\n", | |
| " vals = [v * 100 for _, v in pass_points]\n", | |
| " ax.plot(ns, vals, marker='o', color='#1f77b4', linewidth=2) # pyright: ignore[reportUnknownMemberType]\n", | |
| " for n, pct in zip(ns, vals, strict=True):\n", | |
| " ax.text(n, pct + 0.5, f'{pct:.1f}%', ha='center', va='bottom') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_title('Pass@N progression') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_xlabel('N (samples per problem)') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_ylabel('Success rate (%)') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_ylim(0, 100)\n", | |
| " ax.grid(alpha=0.3) # pyright: ignore[reportUnknownMemberType]\n", | |
| " fig.tight_layout()\n", | |
| " display(fig)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "8c333ca0", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:05.955365Z", | |
| "iopub.status.busy": "2025-11-25T19:15:05.955215Z", | |
| "iopub.status.idle": "2025-11-25T19:15:06.052163Z", | |
| "shell.execute_reply": "2025-11-25T19:15:06.051660Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 1200x500 with 2 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Token length distributions\n", | |
| "\n", | |
| "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) # pyright: ignore[reportUnknownMemberType]\n", | |
| "prompt_lengths = stats['trace/prompt_len']\n", | |
| "response_lengths = stats['trace/response_len']\n", | |
| "\n", | |
| "ax1.hist(\n", | |
| " prompt_lengths,\n", | |
| " bins=30,\n", | |
| " color='#1f77b4',\n", | |
| " alpha=0.75,\n", | |
| " edgecolor='black',\n", | |
| ")\n", | |
| "ax1.axvline(float(prompt_len.mean()), color='red', linestyle='--', label='mean')\n", | |
| "ax1.axvline(prompt_p95, color='purple', linestyle=':', label='p95')\n", | |
| "ax1.set_title('Prompt length distribution')\n", | |
| "ax1.set_xlabel('Tokens')\n", | |
| "ax1.set_ylabel('Frequency')\n", | |
| "ax1.legend()\n", | |
| "ax1.grid(alpha=0.3, axis='y')\n", | |
| "\n", | |
| "ax2.hist(\n", | |
| " response_lengths,\n", | |
| " bins=30,\n", | |
| " color='#ff7f0e',\n", | |
| " alpha=0.75,\n", | |
| " edgecolor='black',\n", | |
| ")\n", | |
| "ax2.axvline(\n", | |
| " float(response_len.mean()), color='red', linestyle='--', label='mean'\n", | |
| ")\n", | |
| "ax2.axvline(response_p95, color='purple', linestyle=':', label='p95')\n", | |
| "ax2.set_title('Response length distribution')\n", | |
| "ax2.set_xlabel('Tokens')\n", | |
| "ax2.legend()\n", | |
| "ax2.grid(alpha=0.3, axis='y')\n", | |
| "\n", | |
| "fig.tight_layout()\n", | |
| "display(fig)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "dd2074f4", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:06.053354Z", | |
| "iopub.status.busy": "2025-11-25T19:15:06.053193Z", | |
| "iopub.status.idle": "2025-11-25T19:15:06.087266Z", | |
| "shell.execute_reply": "2025-11-25T19:15:06.086764Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 600x400 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Success distribution per sequence\n", | |
| "\n", | |
| "if success_buckets:\n", | |
| " fig, ax = plt.subplots(figsize=(6, 4)) # pyright: ignore[reportUnknownMemberType]\n", | |
| " successes = [bucket.successes for bucket in success_buckets]\n", | |
| " counts = [bucket.count for bucket in success_buckets]\n", | |
| " ax.bar(successes, counts, color='#9467bd') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_xticks(successes) # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_xlabel(f'Number of successful samples (out of {config.group_size})') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_ylabel('Sequences') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_title('Distribution of successes per sequence') # pyright: ignore[reportUnknownMemberType]\n", | |
| " for s, c in zip(successes, counts, strict=True):\n", | |
| " ax.text(s, c, f'{c}', ha='center', va='bottom') # pyright: ignore[reportUnknownMemberType]\n", | |
| " fig.tight_layout()\n", | |
| " display(fig)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "eb31674c", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:06.088428Z", | |
| "iopub.status.busy": "2025-11-25T19:15:06.088280Z", | |
| "iopub.status.idle": "2025-11-25T19:15:06.635830Z", | |
| "shell.execute_reply": "2025-11-25T19:15:06.635244Z" | |
| }, | |
| "tags": [ | |
| "rl-accuracy-curves" | |
| ] | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:projects/reflectionai/instances/dataprism/databases/production:Created multiplexed session.\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 800x500 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# RL training progress\n", | |
| "\n", | |
| "# Get RL training scores from trace storage\n", | |
| "# Assert collections exist since they're required\n", | |
| "assert config.rl_train_collection is not None, 'rl_train_collection is required'\n", | |
| "assert config.rl_test_collection is not None, 'rl_test_collection is required'\n", | |
| "\n", | |
| "trace_storage = spanner_trace_storage.create_production_spanner_trace_storage()\n", | |
| "train_scores = trace_storage.get_average_score_per_checkpoint_step(\n", | |
| " config.rl_train_collection\n", | |
| ")\n", | |
| "test_scores = trace_storage.get_average_score_per_checkpoint_step(\n", | |
| " config.rl_test_collection\n", | |
| ")\n", | |
| "\n", | |
| "if train_scores and test_scores:\n", | |
| " fig, ax = plt.subplots(figsize=(8, 5)) # pyright: ignore[reportUnknownMemberType]\n", | |
| "\n", | |
| " train_steps = [s.checkpoint_step for s in train_scores]\n", | |
| " train_accs = [s.average_score * 100 for s in train_scores]\n", | |
| " test_steps = [s.checkpoint_step for s in test_scores]\n", | |
| " test_accs = [s.average_score * 100 for s in test_scores]\n", | |
| "\n", | |
| " sns.lineplot(\n", | |
| " x=train_steps,\n", | |
| " y=train_accs,\n", | |
| " marker='o',\n", | |
| " label='Train',\n", | |
| " ax=ax,\n", | |
| " linewidth=2,\n", | |
| " )\n", | |
| " sns.lineplot(\n", | |
| " x=test_steps,\n", | |
| " y=test_accs,\n", | |
| " marker='s',\n", | |
| " label='Test',\n", | |
| " ax=ax,\n", | |
| " linewidth=2,\n", | |
| " )\n", | |
| "\n", | |
| " ax.set_title('RL Training Progress: Train vs Test Accuracy') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_xlabel('Checkpoint Step') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_ylabel('Accuracy (%)') # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.set_ylim(0, 100)\n", | |
| " ax.legend() # pyright: ignore[reportUnknownMemberType]\n", | |
| " ax.grid(alpha=0.3) # pyright: ignore[reportUnknownMemberType]\n", | |
| "\n", | |
| " fig.tight_layout()\n", | |
| " display(fig)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "fe8cdcf4", | |
| "metadata": { | |
| "execution": { | |
| "iopub.execute_input": "2025-11-25T19:15:06.637228Z", | |
| "iopub.status.busy": "2025-11-25T19:15:06.637059Z", | |
| "iopub.status.idle": "2025-11-25T19:15:06.640875Z", | |
| "shell.execute_reply": "2025-11-25T19:15:06.640374Z" | |
| } | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n", | |
| "=== Pass@N Progression ===\n", | |
| "Pass@1: 54.0% (incremental lift: 54.0%)\n", | |
| "Pass@2: 65.6% (incremental lift: 11.6%)\n", | |
| "Pass@3: 72.5% (incremental lift: 6.9%)\n", | |
| "\n", | |
| "=== Success Distribution ===\n", | |
| "0 successes: 275 sequences\n", | |
| "1 successes: 193 sequences\n", | |
| "2 successes: 231 sequences\n", | |
| "3 successes: 300 sequences\n", | |
| "\n", | |
| "=== Token Length Statistics ===\n", | |
| "Prompt - Mean: 112.5, P95: 215.0\n", | |
| "Response - Mean: 1738.8, P95: 3965.4\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Display key results\n", | |
| "print('\\n=== Pass@N Progression ===')\n", | |
| "for rate in rank_rates:\n", | |
| " print(\n", | |
| " f'Pass@{rate.rank + 1}: {rate.pass_rate:.1%} (incremental lift: {rate.incremental_lift:.1%})'\n", | |
| " )\n", | |
| "\n", | |
| "print('\\n=== Success Distribution ===')\n", | |
| "for bucket in success_buckets:\n", | |
| " print(f'{bucket.successes} successes: {bucket.count} sequences')\n", | |
| "\n", | |
| "print('\\n=== Token Length Statistics ===')\n", | |
| "print(f'Prompt - Mean: {float(prompt_len.mean()):.1f}, P95: {prompt_p95:.1f}')\n", | |
| "print(\n", | |
| " f'Response - Mean: {float(response_len.mean()):.1f}, P95: {response_p95:.1f}'\n", | |
| ")" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "jupytext": { | |
| "cell_metadata_filter": "tags,-all", | |
| "main_language": "python", | |
| "notebook_metadata_filter": "-all" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.9" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment