Last active
January 18, 2026 02:42
Fine-tune Qwen3-4B for Recursive Language Models (RLM) - Infinite Context
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Fine-tuning Qwen3-4B for Recursive Language Model (RLM)\n", | |
| "\n", | |
| "Based on: **\"Recursive Language Models\"** (Zhang et al., 2025) - arXiv:2512.24601\n", | |
| "\n", | |
| "This notebook fine-tunes Qwen3-4B-Instruct-2507 with:\n", | |
| "1. **Coding datasets** - Build strong code generation skills\n", | |
| "2. **RLM trajectories** - Learn REPL interaction patterns\n", | |
| "\n", | |
| "**Run in Google Colab with a T4 or A100 GPU.** bf16 training requires an Ampere-or-newer GPU such as the A100; on a T4, use fp16 instead." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Install dependencies\n", | |
| "!pip install -q torch transformers accelerate peft trl bitsandbytes datasets huggingface_hub" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "print(f\"PyTorch: {torch.__version__}\")\n", | |
| "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", | |
| "if torch.cuda.is_available():\n", | |
| " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", | |
| " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")" | |
| ] | |
| }, | |
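| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Training in bf16 needs an Ampere-or-newer GPU (compute capability 8.0+, e.g. A100), while a T4 is compute capability 7.5. The check below is a small addition, not part of the original workflow. It reads the compute capability directly, because `torch.cuda.is_bf16_supported()` may report `True` on older GPUs where bf16 is only emulated." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "\n", | |
| "# Decide between bf16 and fp16 from the GPU's compute capability\n", | |
| "if torch.cuda.is_available():\n", | |
| "    major, minor = torch.cuda.get_device_capability(0)\n", | |
| "    use_bf16 = major >= 8  # Ampere (8.x) and newer support bf16 natively\n", | |
| "    print(f\"Compute capability: {major}.{minor}\")\n", | |
| "    print(f\"Native bf16: {use_bf16} -> train with {'bf16' if use_bf16 else 'fp16'}\")\n", | |
| "else:\n", | |
| "    print(\"No CUDA GPU detected; 4-bit QLoRA training with bitsandbytes needs one.\")" | |
| ] | |
| }, | |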
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1. RLM System Prompt (from Paper Appendix D.1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "RLM_SYSTEM_PROMPT = '''You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs.\n", | |
| "\n", | |
| "Your context is a {context_type} with {context_total_length} total characters.\n", | |
| "\n", | |
| "The REPL environment is initialized with:\n", | |
| "1. A 'context' variable containing the input data\n", | |
| "2. A 'llm_query' function to recursively call sub-LLMs\n", | |
| "3. The ability to use 'print()' to view outputs\n", | |
| "\n", | |
| "IMPORTANT: Batch llm_query calls efficiently (~200k chars per call).\n", | |
| "\n", | |
| "Write Python code in ```repl blocks. When done, use:\n", | |
| "- FINAL(your answer) for direct answers\n", | |
| "- FINAL_VAR(variable_name) to return a variable\n", | |
| "\n", | |
| "Think step by step and execute your plan immediately.'''\n", | |
| "\n", | |
| "print(f\"System prompt: {len(RLM_SYSTEM_PROMPT)} chars\")" | |
| ] | |
| }, | |
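| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The training examples fill `{context_type}` and `{context_total_length}` by hand. When an actual context object is available, a helper can derive both fields. `describe_context` is a hypothetical sketch, and its labels (e.g. 'list of N items') are our own choice, not the paper's." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def describe_context(context):\n", | |
| "    \"\"\"Hypothetical helper: derive context_type and context_total_length from a context object.\"\"\"\n", | |
| "    if isinstance(context, str):\n", | |
| "        return \"string\", len(context)\n", | |
| "    if isinstance(context, (list, tuple)):\n", | |
| "        # Sum characters across items, converting non-strings (e.g. dicts) with str()\n", | |
| "        total = sum(len(str(item)) for item in context)\n", | |
| "        return f\"list of {len(context)} items\", total\n", | |
| "    return type(context).__name__, len(str(context))\n", | |
| "\n", | |
| "ctype, clen = describe_context([\"first doc\", \"second doc\"])\n", | |
| "prompt = RLM_SYSTEM_PROMPT.format(context_type=ctype, context_total_length=f\"{clen:,}\")\n", | |
| "print(prompt.splitlines()[2])" | |
| ] | |
| }, | |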
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 2. Load Coding Datasets\n", | |
| "\n", | |
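| "We combine several coding datasets with synthetic RLM trajectories:\n", | |
| "- **WizardLM Evol-Instruct V2** - general synthetic instructions, keyword-filtered to code tasks\n", | |
| "- **CodeSearchNet** - Python functions paired with their docstrings\n", | |
| "- **MBPP** - Mostly Basic Python Problems\n", | |
| "- **OpenHermes-2.5** - code-related subset of a general instruction dataset\n", | |
| "- **Custom RLM trajectories** - REPL interaction patterns (generated in Section 3)" | |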
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from datasets import load_dataset, concatenate_datasets, Dataset\n", | |
| "import random\n", | |
| "import json\n", | |
| "\n", | |
| "# ============================================================================\n", | |
| "# 2.1 WizardLM Evol-Instruct V2, filtered to code-related prompts\n", | |
| "# ============================================================================\n", | |
| "print(\"Loading Evol-Instruct-Code...\")\n", | |
| "try:\n", | |
| " evol_code = load_dataset(\"WizardLMTeam/WizardLM_evol_instruct_V2_196k\", split=\"train\")\n", | |
| " # Filter for code-related\n", | |
| " code_keywords = ['code', 'function', 'python', 'program', 'algorithm', 'implement', 'write a', 'def ', 'class ']\n", | |
| " evol_code = evol_code.filter(\n", | |
| " lambda x: any(kw in x['conversations'][0]['value'].lower() for kw in code_keywords)\n", | |
| " )\n", | |
| " evol_code = evol_code.shuffle(seed=42).select(range(min(5000, len(evol_code))))\n", | |
| " print(f\" Evol-Instruct-Code: {len(evol_code)} examples\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Could not load Evol-Instruct: {e}\")\n", | |
| " evol_code = None" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# ============================================================================\n", | |
| "# 2.2 CodeSearchNet (Python functions with docstrings)\n", | |
| "# ============================================================================\n", | |
| "print(\"Loading CodeSearchNet...\")\n", | |
| "try:\n", | |
| " codesearch = load_dataset(\"code_search_net\", \"python\", split=\"train\", trust_remote_code=True)\n", | |
| " codesearch = codesearch.shuffle(seed=42).select(range(min(5000, len(codesearch))))\n", | |
| " print(f\" CodeSearchNet: {len(codesearch)} examples\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Could not load CodeSearchNet: {e}\")\n", | |
| " codesearch = None" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# ============================================================================\n", | |
| "# 2.3 MBPP (Mostly Basic Python Problems)\n", | |
| "# ============================================================================\n", | |
| "print(\"Loading MBPP...\")\n", | |
| "try:\n", | |
| " mbpp = load_dataset(\"mbpp\", split=\"train\", trust_remote_code=True)\n", | |
| " print(f\" MBPP: {len(mbpp)} examples\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Could not load MBPP: {e}\")\n", | |
| " mbpp = None" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# ============================================================================\n", | |
| "# 2.4 OpenHermes Code subset (high quality)\n", | |
| "# ============================================================================\n", | |
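| "print(\"Loading OpenHermes...\")\n", | |
| "try:\n", | |
| "    hermes = load_dataset(\"teknium/OpenHermes-2.5\", split=\"train\")\n", | |
| "\n", | |
| "    def first_human_turn(example):\n", | |
| "        # Some conversations open with a system turn (or are empty), so find the first human turn by role\n", | |
| "        for turn in example.get('conversations') or []:\n", | |
| "            if turn.get('from') == 'human':\n", | |
| "                return str(turn.get('value', ''))\n", | |
| "        return ''\n", | |
| "\n", | |
| "    # Keep code-related prompts ('```' also covers '```python')\n", | |
| "    code_markers = ['```', 'def ', 'function', 'code']\n", | |
| "    hermes = hermes.filter(\n", | |
| "        lambda x: any(kw in first_human_turn(x).lower() for kw in code_markers)\n", | |
| "    )\n", | |
| "    hermes = hermes.shuffle(seed=42).select(range(min(3000, len(hermes))))\n", | |
| "    print(f\"  OpenHermes (code): {len(hermes)} examples\")\n", | |
| "except Exception as e:\n", | |
| "    print(f\"  Could not load OpenHermes: {e}\")\n", | |
| "    hermes = None" | |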
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 3. Generate RLM Training Examples\n", | |
| "\n", | |
| "Create synthetic RLM trajectories that teach the model to:\n", | |
| "- Examine context with code\n", | |
| "- Use `llm_query()` for recursive calls\n", | |
| "- Output `FINAL()` / `FINAL_VAR()`" | |
| ] | |
| }, | |
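| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The generators below place the REPL code and the final answer in a single assistant turn, so the model never sees REPL output during training. One possible alternative is a multi-turn layout in which the REPL's output is fed back before the final answer. This sketch is hypothetical: `make_multiturn_needle_example` is our name, and returning REPL output as a user message prefixed 'REPL output:' is an assumption, not a format confirmed by the paper." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import random\n", | |
| "\n", | |
| "def make_multiturn_needle_example(name=\"SECRET_CODE\"):\n", | |
| "    \"\"\"Hypothetical multi-turn trajectory: code turn -> REPL output turn -> final answer.\"\"\"\n", | |
| "    value = random.randint(1000, 9999)\n", | |
| "    code_turn = (\n", | |
| "        \"```repl\\n\"\n", | |
| "        \"import re\\n\"\n", | |
| "        f\"match = re.search(r\\\"{name}[=:\\\\s]+(\\\\d+)\\\", context)\\n\"\n", | |
| "        \"result = match.group(1) if match else None\\n\"\n", | |
| "        \"print(result)\\n\"\n", | |
| "        \"```\"\n", | |
| "    )\n", | |
| "    return {\n", | |
| "        \"messages\": [\n", | |
| "            {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| "                context_type=\"string\", context_total_length=\"~100,000\")},\n", | |
| "            {\"role\": \"user\", \"content\": f\"Find the {name} value in this document.\"},\n", | |
| "            {\"role\": \"assistant\", \"content\": code_turn},\n", | |
| "            # Observation turn carrying what the REPL printed (message wording is an assumption)\n", | |
| "            {\"role\": \"user\", \"content\": f\"REPL output:\\n{value}\"},\n", | |
| "            {\"role\": \"assistant\", \"content\": \"FINAL_VAR(result)\"},\n", | |
| "        ]\n", | |
| "    }\n", | |
| "\n", | |
| "example = make_multiturn_needle_example()\n", | |
| "for m in example[\"messages\"][1:]:\n", | |
| "    print(f\"[{m['role']}] {m['content'][:60]!r}\")" | |
| ] | |
| }, | |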
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def generate_rlm_needle_example():\n", | |
| " \"\"\"Needle-in-haystack: Find hidden value in large context.\"\"\"\n", | |
| " needle_value = random.randint(1000, 9999)\n", | |
| " patterns = [\n", | |
| " (f\"SECRET_CODE={needle_value}\", \"SECRET_CODE\"),\n", | |
| " (f\"MAGIC_NUMBER: {needle_value}\", \"MAGIC_NUMBER\"),\n", | |
| " (f\"hidden_key = {needle_value}\", \"hidden_key\"),\n", | |
| " (f\"# PASSWORD: {needle_value}\", \"PASSWORD\"),\n", | |
| " ]\n", | |
| " needle, name = random.choice(patterns)\n", | |
| " \n", | |
| " return {\n", | |
| " \"messages\": [\n", | |
| " {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| " context_type=\"string\", context_total_length=\"~100,000\"\n", | |
| " )},\n", | |
| " {\"role\": \"user\", \"content\": f\"Find the {name} value in this document.\"},\n", | |
| " {\"role\": \"assistant\", \"content\": f'''Let me search the context for {name}.\n", | |
| "\n", | |
| "```repl\n", | |
| "print(f\"Context length: {{len(context)}} chars\")\n", | |
| "print(f\"Preview: {{context[:200]}}\")\n", | |
| "```\n", | |
| "\n", | |
| "Now searching for the pattern:\n", | |
| "\n", | |
| "```repl\n", | |
| "import re\n", | |
| "result = None\n", | |
| "pattern = r\"{name}[=:\\\\s]+([\\\\d]+)\"\n", | |
| "match = re.search(pattern, context)\n", | |
| "if match:\n", | |
| "    result = match.group(1)\n", | |
| "    print(f\"Found: {{result}}\")\n", | |
| "else:\n", | |
| "    # Fall back to asking a sub-LLM about each chunk that mentions the key\n", | |
| "    for i in range(0, len(context), 20000):\n", | |
| "        chunk = context[i:i+20000]\n", | |
| "        if \"{name}\" in chunk:\n", | |
| "            result = llm_query(f\"Extract the {name} value from: {{chunk}}\")\n", | |
| "            print(f\"Found in chunk {{i//20000}}: {{result}}\")\n", | |
| "            break\n", | |
| "```\n", | |
| "\n", | |
| "FINAL_VAR(result)'''}\n", | |
| " ]\n", | |
| " }\n", | |
| "\n", | |
| "\n", | |
| "def generate_rlm_aggregation_example():\n", | |
| " \"\"\"Aggregation: Count/classify items across large dataset.\"\"\"\n", | |
| " categories = [\"error\", \"warning\", \"info\", \"debug\"]\n", | |
| " target = random.choice(categories)\n", | |
| " \n", | |
| " return {\n", | |
| " \"messages\": [\n", | |
| " {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| " context_type=\"log file\", context_total_length=\"~500,000\"\n", | |
| " )},\n", | |
| " {\"role\": \"user\", \"content\": f\"Count all '{target}' level log entries.\"},\n", | |
| " {\"role\": \"assistant\", \"content\": f'''I'll count {target} entries in the logs.\n", | |
| "\n", | |
| "```repl\n", | |
| "lines = context.strip().split('\\\\n')\n", | |
| "print(f\"Total lines: {{len(lines)}}\")\n", | |
| "print(f\"Sample: {{lines[:3]}}\")\n", | |
| "```\n", | |
| "\n", | |
| "Now counting:\n", | |
| "\n", | |
| "```repl\n", | |
| "count = 0\n", | |
| "for line in lines:\n", | |
| "    # Match the bracketed level tag (e.g. [ERROR]) rather than any mention of the word\n", | |
| "    if '[{target.upper()}]' in line.upper():\n", | |
| "        count += 1\n", | |
| "\n", | |
| "print(f\"Total {target} entries: {{count}}\")\n", | |
| "result = count\n", | |
| "```\n", | |
| "\n", | |
| "FINAL_VAR(result)'''}\n", | |
| " ]\n", | |
| " }\n", | |
| "\n", | |
| "\n", | |
| "def generate_rlm_code_search_example():\n", | |
| " \"\"\"Code search: Find functions/classes in codebase.\"\"\"\n", | |
| " targets = [\n", | |
| " (\"authentication\", [\"login\", \"logout\", \"auth\", \"verify_token\"]),\n", | |
| " (\"database\", [\"query\", \"connect\", \"execute\", \"fetch\"]),\n", | |
| " (\"validation\", [\"validate\", \"check\", \"verify\", \"sanitize\"]),\n", | |
| " (\"API endpoints\", [\"route\", \"endpoint\", \"get_\", \"post_\"]),  # keywords are matched against function names\n", | |
| " ]\n", | |
| " topic, keywords = random.choice(targets)\n", | |
| " \n", | |
| " return {\n", | |
| " \"messages\": [\n", | |
| " {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| " context_type=\"list of code files\", context_total_length=\"~2,000,000\"\n", | |
| " )},\n", | |
| " {\"role\": \"user\", \"content\": f\"Find all functions related to {topic} and explain what each does.\"},\n", | |
| " {\"role\": \"assistant\", \"content\": f'''I'll search the codebase for {topic}-related functions.\n", | |
| "\n", | |
| "```repl\n", | |
| "print(f\"Files in codebase: {{len(context)}}\")\n", | |
| "print(f\"File structure: {{type(context[0]) if context else 'empty'}}\")\n", | |
| "```\n", | |
| "\n", | |
| "Searching for relevant patterns:\n", | |
| "\n", | |
| "```repl\n", | |
| "import re\n", | |
| "keywords = {keywords}\n", | |
| "found_functions = []\n", | |
| "\n", | |
| "for file_info in context:\n", | |
| "    filename = file_info.get('filename', '')\n", | |
| "    content = file_info.get('content', '')\n", | |
| "\n", | |
| "    # Find function definitions (the pattern also matches defs with return annotations)\n", | |
| "    for match in re.finditer(r'def\\\\s+(\\\\w+)\\\\s*\\\\(', content):\n", | |
| "        func_name = match.group(1)\n", | |
| "        if any(kw in func_name.lower() for kw in keywords):\n", | |
| "            # Keep a ~500-character snippet starting at the def\n", | |
| "            start = match.start()\n", | |
| "            snippet = content[start:start+500]\n", | |
| "            found_functions.append({{\n", | |
| "                'name': func_name,\n", | |
| "                'file': filename,\n", | |
| "                'code': snippet\n", | |
| "            }})\n", | |
| "\n", | |
| "print(f\"Found {{len(found_functions)}} functions\")\n", | |
| "for f in found_functions[:5]:\n", | |
| "    print(f\"- {{f['name']}} in {{f['file']}}\")\n", | |
| "```\n", | |
| "\n", | |
| "Now let me explain each function:\n", | |
| "\n", | |
| "```repl\n", | |
| "explanations = []\n", | |
| "for func in found_functions:\n", | |
| "    explanation = llm_query(f\"\"\"Explain this function in 1 sentence:\n", | |
| "File: {{func['file']}}\n", | |
| "{{func['code']}}\"\"\")\n", | |
| "    explanations.append(f\"- {{func['name']}}: {{explanation}}\")\n", | |
| "\n", | |
| "result = '\\\\n'.join(explanations)\n", | |
| "print(result)\n", | |
| "```\n", | |
| "\n", | |
| "FINAL_VAR(result)'''}\n", | |
| " ]\n", | |
| " }\n", | |
| "\n", | |
| "\n", | |
| "def generate_rlm_multi_doc_example():\n", | |
| " \"\"\"Multi-document QA: Answer questions spanning many documents.\"\"\"\n", | |
| " questions = [\n", | |
| " \"Which documents discuss both performance optimization and security?\",\n", | |
| " \"Find all documents that reference the main API and summarize their purpose.\",\n", | |
| " \"List documents containing TODO comments and what needs to be done.\",\n", | |
| " \"Which files import the 'utils' module and what do they use it for?\",\n", | |
| " ]\n", | |
| " question = random.choice(questions)\n", | |
| " \n", | |
| " return {\n", | |
| " \"messages\": [\n", | |
| " {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| " context_type=\"list of 500 documents\", context_total_length=\"~5,000,000\"\n", | |
| " )},\n", | |
| " {\"role\": \"user\", \"content\": question},\n", | |
| " {\"role\": \"assistant\", \"content\": f'''I'll search through all documents to answer this.\n", | |
| "\n", | |
| "```repl\n", | |
| "print(f\"Total documents: {{len(context)}}\")\n", | |
| "print(f\"Sample doc keys: {{context[0].keys() if context else 'none'}}\")\n", | |
| "```\n", | |
| "\n", | |
| "Searching documents:\n", | |
| "\n", | |
| "```repl\n", | |
| "matching_docs = []\n", | |
| "\n", | |
| "# Process in batches to use llm_query efficiently\n", | |
| "batch_size = 20\n", | |
| "for i in range(0, len(context), batch_size):\n", | |
| "    batch = context[i:i+batch_size]\n", | |
| "    batch_text = '\\\\n---\\\\n'.join(\n", | |
| "        f\"FILE: {{d.get('filename', f'doc_{{i + j}}')}}\\\\n{{d.get('content', '')[:2000]}}\"\n", | |
| "        for j, d in enumerate(batch)\n", | |
| "    )\n", | |
| "\n", | |
| "    result = llm_query(f\"\"\"{question}\n", | |
| "\n", | |
| "Documents:\n", | |
| "{{batch_text}}\n", | |
| "\n", | |
| "List matching filenames only, one per line, or 'NONE' if no matches.\"\"\")\n", | |
| "\n", | |
| "    if result.strip() != 'NONE':\n", | |
| "        matching_docs.extend(result.strip().split('\\\\n'))\n", | |
| "\n", | |
| "    if (i // batch_size) % 5 == 0:\n", | |
| "        print(f\"Processed {{min(i + batch_size, len(context))}}/{{len(context)}} docs, found {{len(matching_docs)}} matches\")\n", | |
| "\n", | |
| "print(f\"\\\\nTotal matches: {{len(matching_docs)}}\")\n", | |
| "result = matching_docs\n", | |
| "```\n", | |
| "\n", | |
| "FINAL_VAR(result)'''}\n", | |
| " ]\n", | |
| " }\n", | |
| "\n", | |
| "\n", | |
| "# Generate RLM examples\n", | |
| "print(\"Generating RLM training examples...\")\n", | |
| "rlm_generators = [\n", | |
| " generate_rlm_needle_example,\n", | |
| " generate_rlm_aggregation_example,\n", | |
| " generate_rlm_code_search_example,\n", | |
| " generate_rlm_multi_doc_example,\n", | |
| "]\n", | |
| "\n", | |
| "random.seed(42)  # make the synthetic trajectories reproducible\n", | |
| "rlm_examples = []\n", | |
| "for i in range(2000): # 2000 RLM examples\n", | |
| " gen = random.choice(rlm_generators)\n", | |
| " rlm_examples.append(gen())\n", | |
| "\n", | |
| "print(f\"Generated {len(rlm_examples)} RLM examples\")" | |
| ] | |
| }, | |
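| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Because these trajectories come from string templates, a template mistake silently propagates to every example. As an added check (not part of the original pipeline), the cell below confirms that each assistant reply has a repl block and a FINAL()/FINAL_VAR() call, and that every repl block at least compiles as Python. The blocks are never executed." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import re\n", | |
| "from collections import Counter\n", | |
| "\n", | |
| "def check_rlm_example(example):\n", | |
| "    \"\"\"Return a list of problems in one synthetic trajectory (an empty list means it looks OK).\"\"\"\n", | |
| "    problems = []\n", | |
| "    reply = example[\"messages\"][-1][\"content\"]\n", | |
| "    blocks = re.findall(r\"```repl\\s*\\n(.*?)```\", reply, re.DOTALL)\n", | |
| "    if not blocks:\n", | |
| "        problems.append(\"no repl block\")\n", | |
| "    if not re.search(r\"FINAL(_VAR)?\\(\", reply):\n", | |
| "        problems.append(\"no FINAL()/FINAL_VAR()\")\n", | |
| "    for block in blocks:\n", | |
| "        try:\n", | |
| "            compile(block, \"<repl>\", \"exec\")  # syntax check only; nothing runs\n", | |
| "        except SyntaxError as e:\n", | |
| "            problems.append(f\"syntax error: {e.msg}\")\n", | |
| "    return problems\n", | |
| "\n", | |
| "issue_counts = Counter(p for ex in rlm_examples for p in check_rlm_example(ex))\n", | |
| "print(f\"Checked {len(rlm_examples)} RLM examples; issues: {dict(issue_counts) or 'none'}\")" | |
| ] | |
| }, | |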
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 4. Format All Datasets" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def format_evol_instruct(example):\n", | |
| "    \"\"\"Format WizardLM Evol-Instruct to chat format.\"\"\"\n", | |
| "    convs = example.get('conversations', [])\n", | |
| "    if len(convs) < 2:\n", | |
| "        return None\n", | |
| "    return {\n", | |
| "        \"messages\": [\n", | |
| "            {\"role\": \"system\", \"content\": \"You are an expert Python programmer. Write clean, efficient code.\"},\n", | |
| "            {\"role\": \"user\", \"content\": convs[0].get('value', '')},\n", | |
| "            {\"role\": \"assistant\", \"content\": convs[1].get('value', '')}\n", | |
| "        ]\n", | |
| "    }\n", | |
| "\n", | |
| "def format_codesearchnet(example):\n", | |
| "    \"\"\"Format CodeSearchNet to chat format.\"\"\"\n", | |
| "    docstring = example.get('func_documentation_string', '')\n", | |
| "    code = example.get('func_code_string', '')\n", | |
| "    if not docstring or not code:\n", | |
| "        return None\n", | |
| "    return {\n", | |
| "        \"messages\": [\n", | |
| "            {\"role\": \"system\", \"content\": \"You are an expert Python programmer.\"},\n", | |
| "            {\"role\": \"user\", \"content\": f\"Write a Python function that: {docstring}\"},\n", | |
| "            {\"role\": \"assistant\", \"content\": f\"```python\\n{code}\\n```\"}\n", | |
| "        ]\n", | |
| "    }\n", | |
| "\n", | |
| "def format_mbpp(example):\n", | |
| "    \"\"\"Format MBPP to chat format.\"\"\"\n", | |
| "    return {\n", | |
| "        \"messages\": [\n", | |
| "            {\"role\": \"system\", \"content\": \"You are an expert Python programmer. Write clean, tested code.\"},\n", | |
| "            {\"role\": \"user\", \"content\": example.get('text', '')},\n", | |
| "            {\"role\": \"assistant\", \"content\": f\"```python\\n{example.get('code', '')}\\n```\"}\n", | |
| "        ]\n", | |
| "    }\n", | |
| "\n", | |
| "def format_hermes(example):\n", | |
| "    \"\"\"Format OpenHermes to chat format using the first human turn and the reply after it.\"\"\"\n", | |
| "    convs = example.get('conversations') or []\n", | |
| "    # Some OpenHermes conversations open with a system turn, so locate the turns by role\n", | |
| "    for k in range(len(convs) - 1):\n", | |
| "        if convs[k].get('from') == 'human' and convs[k + 1].get('from') == 'gpt':\n", | |
| "            return {\n", | |
| "                \"messages\": [\n", | |
| "                    {\"role\": \"system\", \"content\": \"You are an expert programmer and helpful assistant.\"},\n", | |
| "                    {\"role\": \"user\", \"content\": convs[k].get('value', '')},\n", | |
| "                    {\"role\": \"assistant\", \"content\": convs[k + 1].get('value', '')}\n", | |
| "                ]\n", | |
| "            }\n", | |
| "    return None\n", | |
| "\n", | |
| "# Format all datasets; sources that failed to load are None and are skipped\n", | |
| "all_examples = []\n", | |
| "sources = [\n", | |
| "    (\"Evol-Instruct\", evol_code, format_evol_instruct),\n", | |
| "    (\"CodeSearchNet\", codesearch, format_codesearchnet),\n", | |
| "    (\"MBPP\", mbpp, format_mbpp),\n", | |
| "    (\"OpenHermes\", hermes, format_hermes),\n", | |
| "]\n", | |
| "for source_name, dataset, formatter in sources:\n", | |
| "    if dataset is None:\n", | |
| "        continue\n", | |
| "    added = 0\n", | |
| "    for ex in dataset:\n", | |
| "        formatted = formatter(ex)\n", | |
| "        if formatted:\n", | |
| "            all_examples.append(formatted)\n", | |
| "            added += 1\n", | |
| "    print(f\"Added {added} {source_name} examples\")\n", | |
| "\n", | |
| "# Add RLM examples\n", | |
| "all_examples.extend(rlm_examples)\n", | |
| "print(f\"Added {len(rlm_examples)} RLM examples\")\n", | |
| "\n", | |
| "# Shuffle\n", | |
| "random.shuffle(all_examples)\n", | |
| "\n", | |
| "print(f\"\\n=== Total training examples: {len(all_examples)} ===\")" | |
| ] | |
| }, | |
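| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Each source uses its own system prompt, so counting examples by the start of their system prompt gives a rough per-source breakdown of the final mix. This is a convenience check added here; if two sources ever shared a prompt, this count would merge them." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from collections import Counter\n", | |
| "\n", | |
| "# Group by the first line of the system prompt; RLM prompts only differ after line 1\n", | |
| "mix = Counter(ex[\"messages\"][0][\"content\"].splitlines()[0][:60] for ex in all_examples)\n", | |
| "for prompt_start, n in mix.most_common():\n", | |
| "    print(f\"{n:6d}  {n / len(all_examples):6.1%}  {prompt_start}\")" | |
| ] | |
| }, | |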
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Create HuggingFace dataset\n", | |
| "from datasets import Dataset\n", | |
| "\n", | |
| "train_dataset = Dataset.from_list(all_examples)\n", | |
| "print(f\"Dataset size: {len(train_dataset)}\")\n", | |
| "print(f\"Sample: {train_dataset[0]['messages'][1]['content'][:200]}...\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 5. Load Model with QLoRA" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n", | |
| "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n", | |
| "import torch\n", | |
| "\n", | |
| "MODEL_NAME = \"Qwen/Qwen3-4B-Instruct-2507\"\n", | |
| "\n", | |
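| "# QLoRA config - 4-bit NF4 quantization\n", | |
| "# Compute in bf16 on Ampere or newer (e.g. A100); a T4 (compute capability 7.5) lacks bf16, so use fp16\n", | |
| "compute_dtype = torch.bfloat16 if torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16\n", | |
| "bnb_config = BitsAndBytesConfig(\n", | |
| "    load_in_4bit=True,\n", | |
| "    bnb_4bit_quant_type=\"nf4\",\n", | |
| "    bnb_4bit_compute_dtype=compute_dtype,\n", | |
| "    bnb_4bit_use_double_quant=True,\n", | |
| ")\n", | |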
| "\n", | |
| "print(f\"Loading {MODEL_NAME}...\")\n", | |
| "\n", | |
| "# Load tokenizer\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", | |
| "if tokenizer.pad_token is None:\n", | |
| " tokenizer.pad_token = tokenizer.eos_token\n", | |
| "tokenizer.padding_side = \"right\"\n", | |
| "\n", | |
| "# Load model with 4-bit quantization\n", | |
| "model = AutoModelForCausalLM.from_pretrained(\n", | |
| " MODEL_NAME,\n", | |
| " quantization_config=bnb_config,\n", | |
| " device_map=\"auto\",\n", | |
| " trust_remote_code=True,\n", | |
| " attn_implementation=\"eager\", # For compatibility\n", | |
| ")\n", | |
| "\n", | |
| "# Prepare for training\n", | |
| "model = prepare_model_for_kbit_training(model)\n", | |
| "\n", | |
| "print(f\"Model loaded!\")\n", | |
| "print(f\" Hidden size: {model.config.hidden_size}\")\n", | |
| "print(f\" Layers: {model.config.num_hidden_layers}\")\n", | |
| "print(f\" Vocab size: {model.config.vocab_size}\")" | |
| ] | |
| }, | |
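| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Examples longer than the training sequence limit are truncated, which can cut off the end of an assistant reply, including the FINAL() call in RLM trajectories. A quick estimate of how often that happens, assuming the 4096-token limit used in Section 6 and that a random sample of 500 examples is representative:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Tokenize a random sample to see how many examples exceed the training length limit\n", | |
| "LIMIT = 4096  # assumed to match MAX_SEQ_LENGTH in the training section\n", | |
| "sample = train_dataset.shuffle(seed=0).select(range(min(500, len(train_dataset))))\n", | |
| "lengths = []\n", | |
| "for ex in sample:\n", | |
| "    text = tokenizer.apply_chat_template(ex[\"messages\"], tokenize=False)\n", | |
| "    lengths.append(len(tokenizer(text, add_special_tokens=False)[\"input_ids\"]))\n", | |
| "\n", | |
| "lengths.sort()\n", | |
| "over = sum(n > LIMIT for n in lengths)\n", | |
| "print(f\"Median tokens: {lengths[len(lengths) // 2]}, max: {lengths[-1]}\")\n", | |
| "print(f\"{over}/{len(lengths)} sampled examples exceed {LIMIT} tokens and would be truncated\")" | |
| ] | |
| }, | |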
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# LoRA config - target all important layers\n", | |
| "lora_config = LoraConfig(\n", | |
| " r=64, # Rank (higher = more capacity)\n", | |
| " lora_alpha=128, # Alpha scaling\n", | |
| " lora_dropout=0.05,\n", | |
| " bias=\"none\",\n", | |
| " task_type=\"CAUSAL_LM\",\n", | |
| " target_modules=[ # Qwen3 modules\n", | |
| " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", # Attention\n", | |
| " \"gate_proj\", \"up_proj\", \"down_proj\" # MLP\n", | |
| " ],\n", | |
| ")\n", | |
| "\n", | |
| "# Apply LoRA\n", | |
| "model = get_peft_model(model, lora_config)\n", | |
| "model.print_trainable_parameters()" | |
| ] | |
| }, | |
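| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For a linear layer with input size d_in and output size d_out, LoRA adds A (r × d_in) and B (d_out × r), i.e. r·(d_in + d_out) trainable parameters per adapted layer. As an added cross-check (not in the original), the cell below recomputes that total from the shapes of PEFT's `lora_A`/`lora_B` submodules, assuming the default adapter name `default`; it should agree with `print_trainable_parameters()`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "\n", | |
| "# Recompute the LoRA parameter count as r * (d_in + d_out) over every adapted layer\n", | |
| "formula_total = 0\n", | |
| "for name, module in model.named_modules():\n", | |
| "    lora_A = getattr(module, \"lora_A\", None)\n", | |
| "    if isinstance(lora_A, torch.nn.ModuleDict) and \"default\" in lora_A:\n", | |
| "        d_in = lora_A[\"default\"].weight.shape[1]          # A has shape (r, d_in)\n", | |
| "        d_out = module.lora_B[\"default\"].weight.shape[0]  # B has shape (d_out, r)\n", | |
| "        formula_total += lora_config.r * (d_in + d_out)\n", | |
| "\n", | |
| "counted = sum(p.numel() for n, p in model.named_parameters() if \"lora_\" in n)\n", | |
| "print(f\"r*(d_in+d_out) total: {formula_total:,}   LoRA parameters counted: {counted:,}\")" | |
| ] | |
| }, | |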
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 6. Training" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from trl import SFTTrainer, SFTConfig\n", | |
| "\n", | |
| "# Format function for chat template\n", | |
| "def formatting_func(example):\n", | |
| "    return tokenizer.apply_chat_template(\n", | |
| "        example[\"messages\"],\n", | |
| "        tokenize=False,\n", | |
| "        add_generation_prompt=False\n", | |
| "    )\n", | |
| "\n", | |
| "MAX_SEQ_LENGTH = 4096\n", | |
| "\n", | |
| "# bf16 needs an Ampere-or-newer GPU (e.g. A100); fall back to fp16 on a T4\n", | |
| "use_bf16 = torch.cuda.get_device_capability(0)[0] >= 8\n", | |
| "\n", | |
| "# Training config: recent TRL versions take SFTConfig (a TrainingArguments subclass)\n", | |
| "# and read the sequence limit from it rather than from an SFTTrainer argument\n", | |
| "training_args = SFTConfig(\n", | |
| "    output_dir=\"./rlm-qwen3-4b-coder\",\n", | |
| "    num_train_epochs=2,\n", | |
| "    per_device_train_batch_size=1,\n", | |
| "    gradient_accumulation_steps=8,\n", | |
| "    gradient_checkpointing=True,\n", | |
| "    learning_rate=2e-4,\n", | |
| "    lr_scheduler_type=\"cosine\",\n", | |
| "    warmup_ratio=0.05,\n", | |
| "    logging_steps=25,\n", | |
| "    save_steps=500,\n", | |
| "    save_total_limit=3,\n", | |
| "    bf16=use_bf16,\n", | |
| "    fp16=not use_bf16,\n", | |
| "    optim=\"paged_adamw_8bit\",\n", | |
| "    report_to=\"none\",\n", | |
| "    remove_unused_columns=False,\n", | |
| "    max_length=MAX_SEQ_LENGTH,\n", | |
| ")\n", | |
| "\n", | |
| "print(\"Training config:\")\n", | |
| "print(f\"  Epochs: {training_args.num_train_epochs}\")\n", | |
| "print(f\"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}\")\n", | |
| "print(f\"  Learning rate: {training_args.learning_rate}\")\n", | |
| "print(f\"  Precision: {'bf16' if use_bf16 else 'fp16'}\")\n", | |
| "print(f\"  Max seq length: {MAX_SEQ_LENGTH}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Create trainer; the sequence limit comes from training_args.max_length\n", | |
| "trainer = SFTTrainer(\n", | |
| "    model=model,\n", | |
| "    train_dataset=train_dataset,\n", | |
| "    args=training_args,\n", | |
| "    processing_class=tokenizer,\n", | |
| "    formatting_func=formatting_func,\n", | |
| ")\n", | |
| "\n", | |
| "print(f\"\\nStarting training on {len(train_dataset)} examples...\")\n", | |
| "print(\"Rough estimate: ~2-3 hours on T4, ~45 min on A100; actual time depends on dataset size and sequence lengths\\n\")\n", | |
| "\n", | |
| "trainer.train()" | |
| ] | |
| }, | |
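| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The step counts in the training log follow from the batch settings: on a single GPU each optimizer step consumes per_device_train_batch_size × gradient_accumulation_steps examples, so one epoch is roughly N / 8 steps here. The sketch below redoes that arithmetic; the Trainer computes the exact figures itself and may round slightly differently." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import math\n", | |
| "\n", | |
| "# Approximate optimizer-step arithmetic for a single-GPU run\n", | |
| "n_examples = len(train_dataset)\n", | |
| "effective_batch = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps\n", | |
| "steps_per_epoch = math.ceil(n_examples / effective_batch)\n", | |
| "total_steps = steps_per_epoch * math.ceil(training_args.num_train_epochs)\n", | |
| "warmup_steps = math.ceil(training_args.warmup_ratio * total_steps)\n", | |
| "print(f\"Effective batch: {effective_batch}\")\n", | |
| "print(f\"~{steps_per_epoch} optimizer steps per epoch, ~{total_steps} total, ~{warmup_steps} warmup\")\n", | |
| "print(f\"Checkpoint every {training_args.save_steps} steps -> about {total_steps // training_args.save_steps} saves (last {training_args.save_total_limit} kept)\")" | |
| ] | |
| }, | |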
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Save the LoRA adapter\n", | |
| "trainer.save_model(\"./rlm-qwen3-4b-coder-final\")\n", | |
| "tokenizer.save_pretrained(\"./rlm-qwen3-4b-coder-final\")\n", | |
| "print(\"LoRA adapter saved!\")" | |
| ] | |
| }, | |
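| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The saved directory holds only the LoRA adapter. A sketch of reloading it later, assuming a fresh runtime (a second model copy next to the training model may not fit in GPU memory): PEFT's `AutoPeftModelForCausalLM` reads the base model name from the adapter's config and attaches the adapter. Loading in bf16 assumes an Ampere-or-newer GPU; use `torch.float16` on a T4." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from peft import AutoPeftModelForCausalLM\n", | |
| "from transformers import AutoTokenizer\n", | |
| "\n", | |
| "ADAPTER_DIR = \"./rlm-qwen3-4b-coder-final\"\n", | |
| "\n", | |
| "# Loads the base model named in adapter_config.json, then applies the saved LoRA weights\n", | |
| "reloaded_model = AutoPeftModelForCausalLM.from_pretrained(\n", | |
| "    ADAPTER_DIR,\n", | |
| "    torch_dtype=torch.bfloat16,\n", | |
| "    device_map=\"auto\",\n", | |
| ")\n", | |
| "reloaded_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)\n", | |
| "reloaded_model.eval()\n", | |
| "print(type(reloaded_model).__name__)" | |
| ] | |
| }, | |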
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 7. Test the Fine-tuned Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Test 1: Pure coding task\n", | |
| "print(\"=\" * 60)\n", | |
| "print(\"TEST 1: Coding Task\")\n", | |
| "print(\"=\" * 60)\n", | |
| "\n", | |
| "messages = [\n", | |
| " {\"role\": \"system\", \"content\": \"You are an expert Python programmer.\"},\n", | |
| " {\"role\": \"user\", \"content\": \"Write a function to find all prime numbers up to n using the Sieve of Eratosthenes.\"}\n", | |
| "]\n", | |
| "\n", | |
| "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "with torch.no_grad():\n", | |
| " outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, do_sample=True)\n", | |
| "\n", | |
| "response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "print(response)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Test 2: RLM task\n", | |
| "print(\"\\n\" + \"=\" * 60)\n", | |
| "print(\"TEST 2: RLM Task (should write REPL code)\")\n", | |
| "print(\"=\" * 60)\n", | |
| "\n", | |
| "messages = [\n", | |
| " {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| " context_type=\"string\", context_total_length=\"500,000\"\n", | |
| " )},\n", | |
| " {\"role\": \"user\", \"content\": \"Find the SECRET_KEY value hidden somewhere in this document.\"}\n", | |
| "]\n", | |
| "\n", | |
| "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "with torch.no_grad():\n", | |
| " outputs = model.generate(**inputs, max_new_tokens=1024, temperature=0.7, do_sample=True)\n", | |
| "\n", | |
| "response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "print(response)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Test 3: Code search RLM task\n", | |
| "print(\"\\n\" + \"=\" * 60)\n", | |
| "print(\"TEST 3: Code Search RLM Task\")\n", | |
| "print(\"=\" * 60)\n", | |
| "\n", | |
| "messages = [\n", | |
| " {\"role\": \"system\", \"content\": RLM_SYSTEM_PROMPT.format(\n", | |
| " context_type=\"list of 200 Python files\", context_total_length=\"3,000,000\"\n", | |
| " )},\n", | |
| " {\"role\": \"user\", \"content\": \"Find all functions that handle user authentication and explain what each does.\"}\n", | |
| "]\n", | |
| "\n", | |
| "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "with torch.no_grad():\n", | |
| " outputs = model.generate(**inputs, max_new_tokens=1024, temperature=0.7, do_sample=True)\n", | |
| "\n", | |
| "response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "print(response)" | |
| ] | |
| }, | |
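| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The three test cells repeat the same template, generate, and decode steps. A small helper can wrap them; the name `chat` is ours, not from the original. Like the cells above, it slices off the prompt tokens so only the newly generated text is decoded." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "\n", | |
| "def chat(messages, max_new_tokens=512, temperature=0.7):\n", | |
| "    \"\"\"Apply the chat template, sample a reply, and return only the newly generated text.\"\"\"\n", | |
| "    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "    inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "    with torch.no_grad():\n", | |
| "        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens,\n", | |
| "                                 temperature=temperature, do_sample=True)\n", | |
| "    # Drop the prompt tokens so only the model's continuation is decoded\n", | |
| "    return tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "\n", | |
| "print(chat([\n", | |
| "    {\"role\": \"system\", \"content\": \"You are an expert Python programmer.\"},\n", | |
| "    {\"role\": \"user\", \"content\": \"Write a one-line Python expression that reverses a string s.\"},\n", | |
| "], max_new_tokens=128))" | |
| ] | |
| }, | |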
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 8. Merge and Export" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Merge LoRA into base model for deployment\n", | |
| "from peft import PeftModel\n", | |
| "\n", | |
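| "print(\"Loading base model for merging...\")\n", | |
| "\n", | |
| "# Release the 4-bit training model and trainer first; a second, unquantized copy may not fit on a 16 GB GPU\n", | |
| "import gc\n", | |
| "for _name in (\"trainer\", \"model\"):\n", | |
| "    globals().pop(_name, None)\n", | |
| "gc.collect()\n", | |
| "torch.cuda.empty_cache()\n", | |
| "\n", | |
| "# Load the base model unquantized in bf16 so the LoRA weights can be merged into it\n", | |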
| "base_model = AutoModelForCausalLM.from_pretrained(\n", | |
| " MODEL_NAME,\n", | |
| " torch_dtype=torch.bfloat16,\n", | |
| " device_map=\"auto\",\n", | |
| " trust_remote_code=True,\n", | |
| ")\n", | |
| "\n", | |
| "# Load and merge LoRA\n", | |
| "print(\"Merging LoRA weights...\")\n", | |
| "model = PeftModel.from_pretrained(base_model, \"./rlm-qwen3-4b-coder-final\")\n", | |
| "merged_model = model.merge_and_unload()\n", | |
| "\n", | |
| "# Save\n", | |
| "print(\"Saving merged model...\")\n", | |
| "merged_model.save_pretrained(\"./rlm-qwen3-4b-merged\")\n", | |
| "tokenizer.save_pretrained(\"./rlm-qwen3-4b-merged\")\n", | |
| "\n", | |
| "print(\"\\nMerged model saved to ./rlm-qwen3-4b-merged\")\n", | |
| "print(\"You can now upload this to HuggingFace or use it directly!\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Optional: Upload to HuggingFace\n", | |
| "from huggingface_hub import login, HfApi\n", | |
| "\n", | |
| "# login() # Uncomment and run to login\n", | |
| "\n", | |
| "# HF_USERNAME = \"your-username\"\n", | |
| "# merged_model.push_to_hub(f\"{HF_USERNAME}/rlm-qwen3-4b-coder\")\n", | |
| "# tokenizer.push_to_hub(f\"{HF_USERNAME}/rlm-qwen3-4b-coder\")\n", | |
| "# print(f\"Uploaded to https://huggingface.co/{HF_USERNAME}/rlm-qwen3-4b-coder\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## 9. Full RLM Inference Test\n", | |
| "\n", | |
| "Now let's test the model with a **real RLM loop** - executing REPL code and handling sub-LLM calls!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
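| "import re\n", | |
| "from typing import Any, List, Optional, Tuple\n", | |
| "\n", | |
| "class REPLEnvironment:\n", | |
| "    \"\"\"Python REPL for RLM execution.\n", | |
| "\n", | |
| "    Note: this is not a sandbox. exec() falls back to the real builtins, so\n", | |
| "    model-written code can do anything the notebook process can do.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def __init__(self, context: Any, llm_query_fn):\n", | |
| "        self.output_buffer = []\n", | |
| "\n", | |
| "        # Route print() inside executed code into a buffer that is returned to the model\n", | |
| "        def captured_print(*args):\n", | |
| "            self.output_buffer.append(' '.join(str(a) for a in args))\n", | |
| "\n", | |
| "        # Persistent namespace: variables survive across execute() calls\n", | |
| "        self.namespace = {\n", | |
| "            'context': context,\n", | |
| "            'llm_query': llm_query_fn,\n", | |
| "            'print': captured_print,\n", | |
| "            'len': len, 'range': range, 'enumerate': enumerate,\n", | |
| "            'str': str, 'int': int, 'float': float, 'list': list, 'dict': dict,\n", | |
| "            'sum': sum, 'min': min, 'max': max, 'sorted': sorted,\n", | |
| "            're': __import__('re'),\n", | |
| "            'json': __import__('json'),\n", | |
| "        }\n", | |
| "\n", | |
| "    def execute(self, code: str) -> Tuple[bool, str]:\n", | |
| "        self.output_buffer = []\n", | |
| "        try:\n", | |
| "            exec(code, self.namespace)\n", | |
| "            output = '\\n'.join(self.output_buffer)\n", | |
| "            return True, output if output else '[OK]'\n", | |
| "        except Exception as e:\n", | |
| "            # Keep anything printed before the error so the model sees partial progress\n", | |
| "            partial = '\\n'.join(self.output_buffer)\n", | |
| "            return False, (partial + '\\n' if partial else '') + f'Error: {type(e).__name__}: {e}'\n", | |
| "\n", | |
| "    def get_var(self, name: str):\n", | |
| "        return self.namespace.get(name)\n", | |
| "\n", | |
| "\n", | |
| "def extract_repl_code(response: str) -> List[str]:\n", | |
| "    \"\"\"Extract the bodies of ```repl code blocks.\"\"\"\n", | |
| "    return re.findall(r'```repl\\s*\\n(.*?)```', response, re.DOTALL)\n", | |
| "\n", | |
| "\n", | |
| "def extract_final(response: str, repl: REPLEnvironment) -> Optional[str]:\n", | |
| "    \"\"\"Extract FINAL(answer) or FINAL_VAR(variable_name); return None if neither is usable.\"\"\"\n", | |
| "    var_match = re.search(r'FINAL_VAR\\((\\w+)\\)', response)\n", | |
| "    if var_match:\n", | |
| "        name = var_match.group(1)\n", | |
| "        # Only accept FINAL_VAR if the variable exists; otherwise str(None)\n", | |
| "        # would be returned as the literal answer 'None'\n", | |
| "        if name in repl.namespace:\n", | |
| "            return str(repl.namespace[name])\n", | |
| "        return None\n", | |
| "\n", | |
| "    # Capture up to the LAST ')' so answers such as FINAL(f(x) = 3) or multi-line\n", | |
| "    # lists containing parentheses are not cut off at the first ')'\n", | |
| "    final_match = re.search(r'FINAL\\((.*)\\)', response, re.DOTALL)\n", | |
| "    if final_match:\n", | |
| "        return final_match.group(1).strip()\n", | |
| "\n", | |
| "    return None\n", | |
| "\n", | |
| "\n", | |
| "print(\"RLM inference helpers loaded!\")" | |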
| ] | |
| }, | |
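| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Before spending GPU time, you can dry-run the helpers above on a hand-written response. The next cell is a minimal sketch that assumes the previous cell has been run. `fake_response` and `fake_llm_query` are hypothetical stand-ins, and no model is called." | |
| ] | |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Dry run of the RLM helpers on a hand-written response (no model calls).\n", | |
| "# fake_llm_query and fake_response are hypothetical, for illustration only.\n", | |
| "def fake_llm_query(prompt: str) -> str:\n", | |
| "    return \"stub sub-LLM answer\"\n", | |
| "\n", | |
| "fake_response = (\n", | |
| "    \"Let me search the context.\\n\"\n", | |
| "    \"```repl\\n\"\n", | |
| "    \"hits = [line for line in context.splitlines() if 'SECRET_KEY' in line]\\n\"\n", | |
| "    \"answer = hits[0].split('=')[1]\\n\"\n", | |
| "    \"print(answer)\\n\"\n", | |
| "    \"```\\n\"\n", | |
| "    \"FINAL_VAR(answer)\"\n", | |
| ")\n", | |
| "\n", | |
| "dry_repl = REPLEnvironment(\"noise\\nSECRET_KEY=1234\\nmore noise\", fake_llm_query)\n", | |
| "blocks = extract_repl_code(fake_response)\n", | |
| "print(f\"Found {len(blocks)} repl block(s)\")  # Found 1 repl block(s)\n", | |
| "for code in blocks:\n", | |
| "    print(dry_repl.execute(code))  # (True, '1234')\n", | |
| "print(\"Final answer:\", extract_final(fake_response, dry_repl))  # Final answer: 1234" | |
| ] | |
| }, |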
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def run_rlm_inference(\n", | |
| "    model,\n", | |
| "    tokenizer,\n", | |
| "    context: Any,\n", | |
| "    question: str,\n", | |
| "    max_iterations: int = 10,\n", | |
| "    verbose: bool = True\n", | |
| "):\n", | |
| "    \"\"\"\n", | |
| "    Run the full RLM inference loop.\n", | |
| "\n", | |
| "    1. Send the question to the model (the context itself lives only in the REPL)\n", | |
| "    2. Extract and execute ```repl code\n", | |
| "    3. Feed the execution results back to the model\n", | |
| "    4. Repeat until FINAL() / FINAL_VAR() is produced or max_iterations is reached\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    # Sub-LLM query function (recursive calls go to the same model)\n", | |
| "    sub_call_count = [0]\n", | |
| "    def llm_query(prompt: str) -> str:\n", | |
| "        sub_call_count[0] += 1\n", | |
| "        if verbose:\n", | |
| "            print(f\"  [Sub-LLM call #{sub_call_count[0]}]\")\n", | |
| "\n", | |
| "        msgs = [{\"role\": \"user\", \"content\": prompt}]\n", | |
| "        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n", | |
| "        inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "        with torch.no_grad():\n", | |
| "            out = model.generate(**inputs, max_new_tokens=512, temperature=0.7, do_sample=True)\n", | |
| "\n", | |
| "        return tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "\n", | |
| "    # Create REPL\n", | |
| "    repl = REPLEnvironment(context, llm_query)\n", | |
| "\n", | |
| "    # System prompt\n", | |
| "    context_len = len(context) if isinstance(context, str) else len(str(context))\n", | |
| "    system = RLM_SYSTEM_PROMPT.format(\n", | |
| "        context_type=type(context).__name__,\n", | |
| "        context_total_length=f\"{context_len:,}\"\n", | |
| "    )\n", | |
| "\n", | |
| "    # Conversation so far (joined into a single user message each turn)\n", | |
| "    conversation = [question]\n", | |
| "\n", | |
| "    for i in range(max_iterations):\n", | |
| "        if verbose:\n", | |
| "            print(f\"\\n{'='*50}\")\n", | |
| "            print(f\"Iteration {i+1}\")\n", | |
| "            print('='*50)\n", | |
| "\n", | |
| "        # Generate response\n", | |
| "        messages = [\n", | |
| "            {\"role\": \"system\", \"content\": system},\n", | |
| "            {\"role\": \"user\", \"content\": '\\n\\n'.join(conversation)}\n", | |
| "        ]\n", | |
| "\n", | |
| "        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "        inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "        with torch.no_grad():\n", | |
| "            outputs = model.generate(\n", | |
| "                **inputs,\n", | |
| "                max_new_tokens=2048,\n", | |
| "                temperature=0.7,\n", | |
| "                do_sample=True\n", | |
| "            )\n", | |
| "\n", | |
| "        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "\n", | |
| "        if verbose:\n", | |
| "            print(f\"\\nModel response:\\n{response[:500]}{'...' if len(response) > 500 else ''}\")\n", | |
| "\n", | |
| "        # Execute ```repl blocks BEFORE checking for a final answer, so that\n", | |
| "        # FINAL_VAR(x) can refer to a variable defined in this same response\n", | |
| "        code_blocks = extract_repl_code(response)\n", | |
| "        exec_results = []\n", | |
| "        for j, code in enumerate(code_blocks):\n", | |
| "            if verbose:\n", | |
| "                print(f\"\\nExecuting code block {j+1}...\")\n", | |
| "            success, output = repl.execute(code)\n", | |
| "            exec_results.append(f\"[Block {j+1}]: {output[:500]}\")\n", | |
| "            if verbose:\n", | |
| "                print(f\"Output: {output[:200]}{'...' if len(output) > 200 else ''}\")\n", | |
| "\n", | |
| "        # Check for final answer\n", | |
| "        final = extract_final(response, repl)\n", | |
| "        if final:\n", | |
| "            if verbose:\n", | |
| "                print(f\"\\n{'='*50}\")\n", | |
| "                print(f\"FINAL ANSWER: {final}\")\n", | |
| "                print(f\"Sub-LLM calls: {sub_call_count[0]}\")\n", | |
| "                print('='*50)\n", | |
| "            return final\n", | |
| "\n", | |
| "        # No final answer yet: append the response plus REPL output (or a nudge)\n", | |
| "        conversation.append(response)\n", | |
| "        if code_blocks:\n", | |
| "            conversation.append('\\n'.join(exec_results))\n", | |
| "        else:\n", | |
| "            conversation.append(\"[No REPL code found. Write ```repl code or provide FINAL(answer).]\")\n", | |
| "\n", | |
| "    return \"[Max iterations reached]\"\n", | |
| "\n", | |
| "\n", | |
| "print(\"RLM inference function ready!\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# TEST: Needle in Haystack\n", | |
| "# =============================================================================\n", | |
| "import random\n", | |
| "\n", | |
| "print(\"=\"*60)\n", | |
| "print(\"TEST: Needle in Haystack\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "# Generate haystack with hidden needle\n", | |
| "words = ['lorem', 'ipsum', 'dolor', 'sit', 'amet', 'code', 'function', 'data']\n", | |
| "haystack_parts = []\n", | |
| "for i in range(500):\n", | |
| "    line = ' '.join(random.choices(words, k=10))\n", | |
| "    haystack_parts.append(line)\n", | |
| "\n", | |
| "# Insert needle at random position\n", | |
| "needle_pos = random.randint(200, 400)\n", | |
| "secret_value = random.randint(1000, 9999)\n", | |
| "haystack_parts.insert(needle_pos, f\"SECRET_KEY={secret_value}\")\n", | |
| "\n", | |
| "test_context = '\\n'.join(haystack_parts)\n", | |
| "print(f\"Context size: {len(test_context):,} chars\")\n", | |
| "print(f\"Hidden SECRET_KEY={secret_value} at line {needle_pos}\")\n", | |
| "\n", | |
| "# Run RLM\n", | |
| "answer = run_rlm_inference(\n", | |
| " model, \n", | |
| " tokenizer,\n", | |
| " context=test_context,\n", | |
| " question=\"Find the SECRET_KEY value hidden in this document.\",\n", | |
| " verbose=True\n", | |
| ")\n", | |
| "\n", | |
| "print(f\"\\nExpected: {secret_value}\")\n", | |
| "print(f\"Got: {answer}\")\n", | |
| "print(f\"Correct: {str(secret_value) in str(answer)}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# TEST: Aggregation (Count items)\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"TEST: Aggregation\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "# Generate log-like data\n", | |
| "log_levels = ['INFO', 'WARNING', 'ERROR', 'DEBUG']\n", | |
| "log_lines = []\n", | |
| "level_counts = {level: 0 for level in log_levels}\n", | |
| "\n", | |
| "for i in range(200):\n", | |
| "    level = random.choice(log_levels)\n", | |
| "    level_counts[level] += 1\n", | |
| "    log_lines.append(f\"[{level}] 2025-01-{random.randint(1,28):02d} Message {i}\")\n", | |
| "\n", | |
| "test_context = '\\n'.join(log_lines)\n", | |
| "target_level = 'ERROR'\n", | |
| "expected_count = level_counts[target_level]\n", | |
| "\n", | |
| "print(f\"Context: {len(log_lines)} log lines\")\n", | |
| "print(f\"Level counts: {level_counts}\")\n", | |
| "\n", | |
| "answer = run_rlm_inference(\n", | |
| " model,\n", | |
| " tokenizer,\n", | |
| " context=test_context,\n", | |
| " question=f\"Count how many {target_level} level log entries are in this log file.\",\n", | |
| " verbose=True\n", | |
| ")\n", | |
| "\n", | |
| "print(f\"\\nExpected: {expected_count}\")\n", | |
| "print(f\"Got: {answer}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# TEST: Multi-file Code Search\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"TEST: Code Search\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "# Simulate codebase\n", | |
| "fake_codebase = [\n", | |
| "    {\n", | |
| "        'filename': 'auth/login.py',\n", | |
| "        'content': '''def login_user(username, password):\n", | |
| "    \"\"\"Authenticate user with username and password.\"\"\"\n", | |
| "    user = db.get_user(username)\n", | |
| "    if user and verify_password(password, user.hash):\n", | |
| "        return create_session(user)\n", | |
| "    return None\n", | |
| "\n", | |
| "def logout_user(session_id):\n", | |
| "    \"\"\"End user session.\"\"\"\n", | |
| "    return db.delete_session(session_id)\n", | |
| "'''\n", | |
| "    },\n", | |
| "    {\n", | |
| "        'filename': 'auth/tokens.py',\n", | |
| "        'content': '''def verify_token(token):\n", | |
| "    \"\"\"Verify JWT token validity.\"\"\"\n", | |
| "    try:\n", | |
| "        payload = jwt.decode(token, SECRET_KEY)\n", | |
| "        return payload\n", | |
| "    except jwt.InvalidTokenError:\n", | |
| "        return None\n", | |
| "'''\n", | |
| "    },\n", | |
| "    {\n", | |
| "        'filename': 'api/users.py',\n", | |
| "        'content': '''def get_users():\n", | |
| "    \"\"\"Get all users.\"\"\"\n", | |
| "    return db.query(User).all()\n", | |
| "\n", | |
| "def create_user(data):\n", | |
| "    \"\"\"Create new user.\"\"\"\n", | |
| "    user = User(**data)\n", | |
| "    db.add(user)\n", | |
| "    return user\n", | |
| "'''\n", | |
| "    },\n", | |
| "    {\n", | |
| "        'filename': 'utils/helpers.py',\n", | |
| "        'content': '''def format_date(dt):\n", | |
| "    return dt.strftime(\"%Y-%m-%d\")\n", | |
| "\n", | |
| "def validate_email(email):\n", | |
| "    import re\n", | |
| "    return bool(re.match(r\"[^@]+@[^@]+\\\\.[^@]+\", email))\n", | |
| "'''\n", | |
| "    }\n", | |
| "]\n", | |
| "\n", | |
| "print(f\"Codebase: {len(fake_codebase)} files\")\n", | |
| "\n", | |
| "answer = run_rlm_inference(\n", | |
| " model,\n", | |
| " tokenizer,\n", | |
| " context=fake_codebase,\n", | |
| " question=\"Find all functions related to authentication and list them with their file paths.\",\n", | |
| " verbose=True\n", | |
| ")\n", | |
| "\n", | |
| "print(f\"\\nAnswer:\\n{answer}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# Interactive Test - Try Your Own!\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"INTERACTIVE: Try your own RLM query!\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "# Put your own context here\n", | |
| "my_context = \"\"\"\n", | |
| "This is a sample document. You can replace this with:\n", | |
| "- A long text file\n", | |
| "- A list of documents\n", | |
| "- Code files\n", | |
| "- Log data\n", | |
| "- Anything you want the model to analyze!\n", | |
| "\n", | |
| "The RLM will write Python code to examine this context\n", | |
| "and recursively call sub-LLMs if needed.\n", | |
| "\n", | |
| "Hidden secret: ANSWER=42\n", | |
| "\"\"\"\n", | |
| "\n", | |
| "my_question = \"What is the hidden ANSWER value?\"\n", | |
| "\n", | |
| "answer = run_rlm_inference(\n", | |
| " model,\n", | |
| " tokenizer,\n", | |
| " context=my_context,\n", | |
| " question=my_question,\n", | |
| " verbose=True\n", | |
| ")\n", | |
| "\n", | |
| "print(f\"\\nFinal Answer: {answer}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
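| "## 10. Benchmarks (Small Subsets)\n", | |
| "\n", | |
| "Evaluate on small subsets of standard coding benchmarks and on synthetic long-context tasks. These are quick sanity checks and are not comparable to published full-benchmark scores:\n", | |
| "- **HumanEval**: code generation, pass@1 on 3 problems (one sample each)\n", | |
| "- **MBPP**: the first 20 problems of the test split\n", | |
| "- **S-NIAH (RULER-style)**: synthetic single needle-in-a-haystack at 8K, 16K and 32K characters\n", | |
| "- **Multi-Key (RULER-style)**: synthetic retrieval of 3-5 key-value pairs from ~20K characters" | |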
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# BENCHMARK 1: HumanEval (Code Generation)\n", | |
| "# =============================================================================\n", | |
| "print(\"=\"*60)\n", | |
| "print(\"BENCHMARK: HumanEval (subset)\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "# Sample HumanEval problems\n", | |
| "humaneval_problems = [\n", | |
| "    {\n", | |
| "        \"task_id\": \"HumanEval/0\",\n", | |
| "        \"prompt\": \"from typing import List\\n\\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\\n    \\\"\\\"\\\" Check if in given list of numbers, are any two numbers closer to each other than\\n    given threshold.\\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\\n    False\\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\\n    True\\n    \\\"\\\"\\\"\",\n", | |
| "        \"canonical_solution\": \"    for idx, elem in enumerate(numbers):\\n        for idx2, elem2 in enumerate(numbers):\\n            if idx != idx2:\\n                distance = abs(elem - elem2)\\n                if distance < threshold:\\n                    return True\\n    return False\\n\",\n", | |
| "        \"test\": \"def check(candidate):\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0], 2.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\\n\"\n", | |
| "    },\n", | |
| "    {\n", | |
| "        \"task_id\": \"HumanEval/1\",\n", | |
| "        \"prompt\": \"from typing import List\\n\\ndef separate_paren_groups(paren_string: str) -> List[str]:\\n    \\\"\\\"\\\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\\n    separate those group into separate strings and return the list of those.\\n    Separate groups are balanced (each open brace is properly closed) and not nested within each other\\n    Ignore any spaces in the input string.\\n    >>> separate_paren_groups('( ) (( )) (( )( ))')\\n    ['()', '(())', '(()())']\\n    \\\"\\\"\\\"\",\n", | |
| "        \"canonical_solution\": \"    result = []\\n    current_string = []\\n    current_depth = 0\\n\\n    for c in paren_string:\\n        if c == '(':\\n            current_depth += 1\\n            current_string.append(c)\\n        elif c == ')':\\n            current_depth -= 1\\n            current_string.append(c)\\n\\n            if current_depth == 0:\\n                result.append(''.join(current_string))\\n                current_string = []\\n\\n    return result\\n\",\n", | |
| "        \"test\": \"def check(candidate):\\n    assert candidate('(()()) ((())) () ((())()())') == ['(()())', '((()))', '()', '((())()())']\\n    assert candidate('() (()) ((())) (((())))') == ['()', '(())', '((()))', '(((())))']\\n    assert candidate('(()(()))') == ['(()(()))']\\n\"\n", | |
| "    },\n", | |
| "    {\n", | |
| "        \"task_id\": \"HumanEval/4\",\n", | |
| "        \"prompt\": \"from typing import List\\n\\ndef mean_absolute_deviation(numbers: List[float]) -> float:\\n    \\\"\\\"\\\" For a given list of input numbers, calculate Mean Absolute Deviation\\n    around the mean of this dataset.\\n    Mean Absolute Deviation is the average absolute difference between each\\n    element and a centerpoint (mean in this case):\\n    MAD = average | x - x_mean |\\n    >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])\\n    1.0\\n    \\\"\\\"\\\"\",\n", | |
| "        \"canonical_solution\": \"    mean = sum(numbers) / len(numbers)\\n    return sum(abs(x - mean) for x in numbers) / len(numbers)\\n\",\n", | |
| "        \"test\": \"def check(candidate):\\n    assert abs(candidate([1.0, 2.0, 3.0]) - 2.0/3.0) < 1e-6\\n    assert abs(candidate([1.0, 2.0, 3.0, 4.0]) - 1.0) < 1e-6\\n    assert abs(candidate([1.0, 2.0, 3.0, 4.0, 5.0]) - 6.0/5.0) < 1e-6\\n\"\n", | |
| "    }\n", | |
| "]\n", | |
| "\n", | |
| "def run_humaneval_test(model, tokenizer, problem):\n", | |
| "    \"\"\"Run a single HumanEval problem; returns (passed, code_or_error).\"\"\"\n", | |
| "    messages = [\n", | |
| "        {\"role\": \"system\", \"content\": \"You are an expert Python programmer. Complete the function. Output ONLY the function body, no explanations.\"},\n", | |
| "        {\"role\": \"user\", \"content\": f\"Complete this function:\\n\\n{problem['prompt']}\"}\n", | |
| "    ]\n", | |
| "\n", | |
| "    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "    inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "    with torch.no_grad():\n", | |
| "        outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.2, do_sample=True)\n", | |
| "\n", | |
| "    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "\n", | |
| "    # Extract code from a fenced block if the model used one\n", | |
| "    if '```python' in response:\n", | |
| "        code = response.split('```python')[1].split('```')[0]\n", | |
| "    elif '```' in response:\n", | |
| "        code = response.split('```')[1].split('```')[0]\n", | |
| "    else:\n", | |
| "        code = response\n", | |
| "\n", | |
| "    # Run the tests. The prompt ends with the closing docstring quotes, so the\n", | |
| "    # completion must start on a new line. WARNING: exec() runs model-generated\n", | |
| "    # code inside this process with no sandbox or timeout.\n", | |
| "    entry_point = problem['prompt'].split('def ')[1].split('(')[0]\n", | |
| "    try:\n", | |
| "        full_code = problem['prompt'] + '\\n' + code + '\\n' + problem['test'] + f\"\\ncheck({entry_point})\"\n", | |
| "        exec(full_code, {})\n", | |
| "        return True, code\n", | |
| "    except Exception as e:\n", | |
| "        return False, str(e)\n", | |
| "\n", | |
| "# Run HumanEval subset\n", | |
| "humaneval_results = []\n", | |
| "for prob in humaneval_problems:\n", | |
| "    passed, result = run_humaneval_test(model, tokenizer, prob)\n", | |
| "    humaneval_results.append(passed)\n", | |
| "    status = '✓ PASS' if passed else '✗ FAIL'\n", | |
| "    print(f\"{prob['task_id']}: {status}\")\n", | |
| "\n", | |
| "humaneval_score = sum(humaneval_results) / len(humaneval_results) * 100\n", | |
| "print(f\"\\nHumanEval (subset): {humaneval_score:.1f}% ({sum(humaneval_results)}/{len(humaneval_results)})\")" | |
| ] | |
| }, | |
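| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The cell above draws one completion per problem at temperature 0.2. On only three problems, that is a noisy pass@1. If you sample `n` completions per problem and `c` of them pass, the unbiased estimator from the Codex paper (Chen et al., 2021), `pass@k = 1 - C(n-c, k) / C(n, k)`, gives a steadier number. The sketch below implements only that formula. `pass_at_k` is a hypothetical helper name, and sampling `n` completions per problem (for example by calling `run_humaneval_test` in a loop) is left to you." | |
| ] | |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from math import comb\n", | |
| "\n", | |
| "def pass_at_k(n: int, c: int, k: int) -> float:\n", | |
| "    \"\"\"Unbiased pass@k for one problem: n samples drawn, c of them passed the tests.\"\"\"\n", | |
| "    if not 0 <= c <= n or not 1 <= k <= n:\n", | |
| "        raise ValueError(\"need 0 <= c <= n and 1 <= k <= n\")\n", | |
| "    if n - c < k:\n", | |
| "        # Every size-k subset of the n samples contains at least one passing sample\n", | |
| "        return 1.0\n", | |
| "    return 1.0 - comb(n - c, k) / comb(n, k)\n", | |
| "\n", | |
| "# Example: 10 samples for one problem, 3 of which passed\n", | |
| "print(f\"pass@1 = {pass_at_k(10, 3, 1):.3f}\")  # pass@1 = 0.300\n", | |
| "print(f\"pass@5 = {pass_at_k(10, 3, 5):.3f}\")  # pass@5 = 0.917\n", | |
| "\n", | |
| "# The benchmark score is the mean over problems, e.g. with per-problem (n, c) counts:\n", | |
| "counts = [(10, 3), (10, 0), (10, 10)]\n", | |
| "print(f\"mean pass@1 = {sum(pass_at_k(n, c, 1) for n, c in counts) / len(counts):.3f}\")  # mean pass@1 = 0.433" | |
| ] | |
| }, |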
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# BENCHMARK 2: MBPP (Mostly Basic Python Problems)\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"BENCHMARK: MBPP (subset)\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "# Load MBPP test split\n", | |
| "try:\n", | |
| "    mbpp_test = load_dataset(\"mbpp\", split=\"test\", trust_remote_code=True)\n", | |
| "    mbpp_test = mbpp_test.select(range(min(20, len(mbpp_test))))  # Test on the first 20 problems\n", | |
| "\n", | |
| "    def run_mbpp_test(model, tokenizer, example):\n", | |
| "        \"\"\"Run a single MBPP problem.\"\"\"\n", | |
| "        # Show the assert statements in the prompt, as in the usual MBPP setup;\n", | |
| "        # otherwise the model has to guess the expected function name and signature\n", | |
| "        tests = '\\n'.join(example['test_list'])\n", | |
| "        messages = [\n", | |
| "            {\"role\": \"system\", \"content\": \"You are an expert Python programmer. Write a function to solve the problem. Output ONLY the Python code.\"},\n", | |
| "            {\"role\": \"user\", \"content\": f\"{example['text']}\\nYour code should pass these tests:\\n{tests}\"}\n", | |
| "        ]\n", | |
| "\n", | |
| "        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", | |
| "        inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n", | |
| "\n", | |
| "        with torch.no_grad():\n", | |
| "            outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.2, do_sample=True)\n", | |
| "\n", | |
| "        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n", | |
| "\n", | |
| "        # Extract code\n", | |
| "        if '```python' in response:\n", | |
| "            code = response.split('```python')[1].split('```')[0]\n", | |
| "        elif '```' in response:\n", | |
| "            code = response.split('```')[1].split('```')[0]\n", | |
| "        else:\n", | |
| "            code = response\n", | |
| "\n", | |
| "        # Run tests (WARNING: executes model-generated code with no sandbox or timeout)\n", | |
| "        try:\n", | |
| "            exec(code + '\\n' + tests, {})\n", | |
| "            return True\n", | |
| "        except Exception:\n", | |
| "            return False\n", | |
| "\n", | |
| "    mbpp_results = []\n", | |
| "    for i, ex in enumerate(mbpp_test):\n", | |
| "        passed = run_mbpp_test(model, tokenizer, ex)\n", | |
| "        mbpp_results.append(passed)\n", | |
| "        if (i+1) % 5 == 0:\n", | |
| "            print(f\"Progress: {i+1}/{len(mbpp_test)} - Running accuracy: {sum(mbpp_results)/len(mbpp_results)*100:.1f}%\")\n", | |
| "\n", | |
| "    mbpp_score = sum(mbpp_results) / len(mbpp_results) * 100\n", | |
| "    print(f\"\\nMBPP (subset): {mbpp_score:.1f}% ({sum(mbpp_results)}/{len(mbpp_results)})\")\n", | |
| "\n", | |
| "except Exception as e:\n", | |
| "    print(f\"Could not run MBPP benchmark: {e}\")\n", | |
| "    mbpp_score = None" | |
| ] | |
| }, | |
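| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Both harnesses above call `exec()` in the notebook process, so a generated solution that never terminates would hang the notebook. One possible mitigation is to run each candidate in a child process with a timeout. The sketch below assumes a Linux runtime such as Colab, where `multiprocessing` starts processes with `fork` by default. `run_code_with_timeout` is a hypothetical helper and is not wired into the harnesses; you could call it in place of their direct `exec()` calls. It isolates hangs and crashes, but it is still not a security sandbox." | |
| ] | |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import multiprocessing as mp\n", | |
| "import sys\n", | |
| "\n", | |
| "def _run_candidate(code: str) -> None:\n", | |
| "    \"\"\"Child-process entry point: exit code 0 only if the code (asserts included) runs cleanly.\"\"\"\n", | |
| "    try:\n", | |
| "        exec(code, {})\n", | |
| "    except BaseException:\n", | |
| "        sys.exit(1)\n", | |
| "\n", | |
| "def run_code_with_timeout(code: str, timeout: float = 10.0) -> bool:\n", | |
| "    \"\"\"Run `code` in a child process; return False on any error or on timeout.\"\"\"\n", | |
| "    proc = mp.Process(target=_run_candidate, args=(code,))\n", | |
| "    proc.start()\n", | |
| "    proc.join(timeout)\n", | |
| "    if proc.is_alive():\n", | |
| "        # Still running after `timeout` seconds: kill it and count it as a failure\n", | |
| "        proc.terminate()\n", | |
| "        proc.join()\n", | |
| "        return False\n", | |
| "    return proc.exitcode == 0\n", | |
| "\n", | |
| "# Quick checks\n", | |
| "print(run_code_with_timeout(\"assert 1 + 1 == 2\"))            # True\n", | |
| "print(run_code_with_timeout(\"assert 1 + 1 == 3\"))            # False\n", | |
| "print(run_code_with_timeout(\"while True: pass\", timeout=2))  # False" | |
| ] | |
| }, |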
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# BENCHMARK 3: S-NIAH (Single Needle in a Haystack) - from RULER\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"BENCHMARK: S-NIAH (RULER style)\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "def generate_ruler_niah(context_length: int, needle_depth: float = 0.5):\n", | |
| "    \"\"\"\n", | |
| "    Generate a RULER-style single needle-in-a-haystack test.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        context_length: Target context length in characters (not tokens)\n", | |
| "        needle_depth: Where to place needle (0.0 = start, 1.0 = end)\n", | |
| "    \"\"\"\n", | |
| "    # The needle\n", | |
| "    magic_number = random.randint(10000, 99999)\n", | |
| "    needle = f\"The special magic number is {magic_number}.\"\n", | |
| "\n", | |
| "    # Haystack: short quotes repeated at random, a lightweight stand-in for\n", | |
| "    # the Paul Graham essays used as filler in the original NIAH test\n", | |
| "    filler_sentences = [\n", | |
| "        \"The best way to predict the future is to invent it.\",\n", | |
| "        \"Innovation distinguishes between a leader and a follower.\",\n", | |
| "        \"Stay hungry, stay foolish.\",\n", | |
| "        \"The only way to do great work is to love what you do.\",\n", | |
| "        \"Design is not just what it looks like and feels like.\",\n", | |
| "        \"Design is how it works.\",\n", | |
| "        \"Simple can be harder than complex.\",\n", | |
| "        \"You have to work hard to get your thinking clean.\",\n", | |
| "        \"The people who are crazy enough to think they can change the world are the ones who do.\",\n", | |
| "        \"Quality is more important than quantity.\",\n", | |
| "        \"One home run is much better than two doubles.\",\n", | |
| "    ]\n", | |
| "\n", | |
| "    haystack = []\n", | |
| "    current_len = 0\n", | |
| "    needle_pos = int(context_length * needle_depth)\n", | |
| "    needle_inserted = False\n", | |
| "\n", | |
| "    while current_len < context_length:\n", | |
| "        if not needle_inserted and current_len >= needle_pos:\n", | |
| "            haystack.append(needle)\n", | |
| "            current_len += len(needle) + 1  # +1 for the joining space\n", | |
| "            needle_inserted = True\n", | |
| "        else:\n", | |
| "            sentence = random.choice(filler_sentences)\n", | |
| "            haystack.append(sentence)\n", | |
| "            current_len += len(sentence) + 1\n", | |
| "\n", | |
| "    context = ' '.join(haystack)\n", | |
| "    question = \"What is the special magic number mentioned in the text?\"\n", | |
| "\n", | |
| "    return context, question, str(magic_number)\n", | |
| "\n", | |
| "def run_niah_benchmark(model, tokenizer, context_lengths=(8000, 16000, 32000), n_samples=3):\n", | |
| "    \"\"\"Run S-NIAH at different context lengths (in characters).\"\"\"\n", | |
| "    results = {}\n", | |
| "\n", | |
| "    for ctx_len in context_lengths:\n", | |
| "        print(f\"\\nTesting context length: {ctx_len:,} chars\")\n", | |
| "        correct = 0\n", | |
| "\n", | |
| "        for i in range(n_samples):\n", | |
| "            # Random needle depth\n", | |
| "            depth = random.uniform(0.2, 0.8)\n", | |
| "            context, question, expected = generate_ruler_niah(ctx_len, depth)\n", | |
| "\n", | |
| "            # Run with RLM\n", | |
| "            answer = run_rlm_inference(\n", | |
| "                model, tokenizer,\n", | |
| "                context=context,\n", | |
| "                question=question,\n", | |
| "                verbose=False,\n", | |
| "                max_iterations=5\n", | |
| "            )\n", | |
| "\n", | |
| "            # Check if answer contains the magic number\n", | |
| "            is_correct = expected in str(answer)\n", | |
| "            correct += int(is_correct)\n", | |
| "\n", | |
| "            status = '✓' if is_correct else '✗'\n", | |
| "            print(f\"  Sample {i+1}: {status} (depth={depth:.2f}, expected={expected}, got={str(answer)[:50]}...)\")\n", | |
| "\n", | |
| "        accuracy = correct / n_samples * 100\n", | |
| "        results[ctx_len] = accuracy\n", | |
| "        print(f\"  Accuracy at {ctx_len:,}: {accuracy:.1f}%\")\n", | |
| "\n", | |
| "    return results\n", | |
| "\n", | |
| "# Run S-NIAH benchmark\n", | |
| "niah_results = run_niah_benchmark(\n", | |
| "    model, tokenizer,\n", | |
| "    context_lengths=[8000, 16000, 32000],\n", | |
| "    n_samples=3\n", | |
| ")\n", | |
| "\n", | |
| "print(\"\\nS-NIAH Results:\")\n", | |
| "for ctx_len, acc in niah_results.items():\n", | |
| "    print(f\"  {ctx_len:>6,} chars: {acc:.1f}%\")" | |
| ] | |
| }, | |
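| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The S-NIAH context lengths above are measured in characters, while RULER and most long-context results are reported in tokens. One quick way to relate the two is to tokenize one generated haystack per length. This assumes the notebook's `tokenizer` and the `generate_ruler_niah` function above are defined; `count_tokens` is a hypothetical helper name." | |
| ] | |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def count_tokens(tokenizer, text: str) -> int:\n", | |
| "    \"\"\"Number of tokens `tokenizer` produces for `text`, without special tokens.\"\"\"\n", | |
| "    return len(tokenizer(text, add_special_tokens=False).input_ids)\n", | |
| "\n", | |
| "# One sample haystack per character length, needle in the middle\n", | |
| "for ctx_len in (8000, 16000, 32000):\n", | |
| "    sample_context, _, _ = generate_ruler_niah(ctx_len, needle_depth=0.5)\n", | |
| "    n_tok = count_tokens(tokenizer, sample_context)\n", | |
| "    print(f\"{ctx_len:>6,} chars -> {n_tok:>6,} tokens ({len(sample_context) / n_tok:.2f} chars/token)\")" | |
| ] | |
| }, |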
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# BENCHMARK 4: Multi-Key Retrieval (RULER style)\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"BENCHMARK: Multi-Key Retrieval\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "def generate_multikey_test(context_length: int, n_keys: int = 5):\n", | |
| "    \"\"\"\n", | |
| "    Generate a multi-key retrieval test.\n", | |
| "    Several key-value pairs are scattered through filler text; all must be retrieved.\n", | |
| "    \"\"\"\n", | |
| "    # Generate key-value pairs\n", | |
| "    keys = [f\"KEY_{i}\" for i in range(n_keys)]\n", | |
| "    values = [random.randint(100, 999) for _ in range(n_keys)]\n", | |
| "    kv_pairs = list(zip(keys, values))\n", | |
| "\n", | |
| "    # Filler text\n", | |
| "    filler = [\n", | |
| "        \"The system processes data efficiently.\",\n", | |
| "        \"Users can access the dashboard anytime.\",\n", | |
| "        \"Performance metrics are tracked daily.\",\n", | |
| "        \"Security protocols are regularly updated.\",\n", | |
| "        \"The API supports multiple endpoints.\",\n", | |
| "    ]\n", | |
| "\n", | |
| "    # Build context with scattered KV pairs\n", | |
| "    parts = []\n", | |
| "    current_len = 0\n", | |
| "    kv_positions = sorted(random.sample(range(10, 90), n_keys))  # Positions as % of the context\n", | |
| "    kv_idx = 0\n", | |
| "\n", | |
| "    while current_len < context_length:\n", | |
| "        progress = current_len / context_length * 100\n", | |
| "\n", | |
| "        # Insert KV pair at designated position\n", | |
| "        if kv_idx < n_keys and progress >= kv_positions[kv_idx]:\n", | |
| "            k, v = kv_pairs[kv_idx]\n", | |
| "            parts.append(f\"Important: {k} = {v}.\")\n", | |
| "            kv_idx += 1\n", | |
| "        else:\n", | |
| "            parts.append(random.choice(filler))\n", | |
| "\n", | |
| "        # Track length incrementally (+1 for the joining space) instead of re-summing all parts\n", | |
| "        current_len += len(parts[-1]) + 1\n", | |
| "\n", | |
| "    context = ' '.join(parts)\n", | |
| "    question = f\"Find all the KEY values (KEY_0 through KEY_{n_keys-1}) and list them.\"\n", | |
| "    expected = {k: v for k, v in kv_pairs}\n", | |
| "\n", | |
| "    return context, question, expected\n", | |
| "\n", | |
| "def run_multikey_benchmark(model, tokenizer, n_samples=3):\n", | |
| "    \"\"\"Run the multi-key retrieval benchmark on ~20,000-character contexts.\"\"\"\n", | |
| "    results = []\n", | |
| "\n", | |
| "    for i in range(n_samples):\n", | |
| "        n_keys = random.randint(3, 5)\n", | |
| "        context, question, expected = generate_multikey_test(20000, n_keys)\n", | |
| "\n", | |
| "        print(f\"\\nSample {i+1}: {n_keys} keys to find\")\n", | |
| "\n", | |
| "        # Run with RLM\n", | |
| "        answer = run_rlm_inference(\n", | |
| "            model, tokenizer,\n", | |
| "            context=context,\n", | |
| "            question=question,\n", | |
| "            verbose=False,\n", | |
| "            max_iterations=8\n", | |
| "        )\n", | |
| "\n", | |
| "        # Count keys whose value appears in the answer (a loose substring check:\n", | |
| "        # a 3-digit value can also match inside an unrelated number)\n", | |
| "        found = 0\n", | |
| "        for k, v in expected.items():\n", | |
| "            if str(v) in str(answer):\n", | |
| "                found += 1\n", | |
| "                print(f\"  ✓ Found {k}={v}\")\n", | |
| "            else:\n", | |
| "                print(f\"  ✗ Missing {k}={v}\")\n", | |
| "\n", | |
| "        recall = found / len(expected) * 100\n", | |
| "        results.append(recall)\n", | |
| "        print(f\"  Recall: {recall:.1f}% ({found}/{len(expected)})\")\n", | |
| "\n", | |
| "    avg_recall = sum(results) / len(results)\n", | |
| "    print(f\"\\nMulti-Key Average Recall: {avg_recall:.1f}%\")\n", | |
| "    return avg_recall\n", | |
| "\n", | |
| "multikey_score = run_multikey_benchmark(model, tokenizer, n_samples=3)" | |
| ] | |
| }, | |
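| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "`run_multikey_benchmark` counts a key as found if its value appears anywhere in the answer. A 3-digit value can therefore match inside an unrelated number, and a value reported under the wrong key still counts. The stricter sketch below parses key-value pairs out of the answer and scores only exact matches. It assumes the model states pairs roughly as `KEY_i = value` or `KEY_i: value`, which is an assumption about its output format. `strict_multikey_recall` is a hypothetical helper." | |
| ] | |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import re\n", | |
| "from typing import Dict\n", | |
| "\n", | |
| "def strict_multikey_recall(answer: str, expected: Dict[str, int]) -> float:\n", | |
| "    \"\"\"Fraction of keys whose value is reported next to the correct key name.\"\"\"\n", | |
| "    # Assumed answer format: \"KEY_3 = 412\", \"KEY_3: 412\" or \"KEY_3=412\"\n", | |
| "    reported = {\n", | |
| "        key: int(value)\n", | |
| "        for key, value in re.findall(r'(KEY_\\d+)\\s*[:=]\\s*(\\d+)', str(answer))\n", | |
| "    }\n", | |
| "    hits = sum(1 for key, value in expected.items() if reported.get(key) == value)\n", | |
| "    return hits / len(expected)\n", | |
| "\n", | |
| "# Made-up example: KEY_1 has the wrong value and KEY_2 is missing\n", | |
| "example_expected = {\"KEY_0\": 123, \"KEY_1\": 456, \"KEY_2\": 789}\n", | |
| "example_answer = \"KEY_0 = 123, KEY_1: 999\"\n", | |
| "print(f\"{strict_multikey_recall(example_answer, example_expected):.3f}\")  # 0.333" | |
| ] | |
| }, |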
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# =============================================================================\n", | |
| "# BENCHMARK SUMMARY\n", | |
| "# =============================================================================\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"BENCHMARK SUMMARY\")\n", | |
| "print(\"=\"*60)\n", | |
| "\n", | |
| "print(\"\\n### Coding Benchmarks ###\")\n", | |
| "print(f\"HumanEval (subset): {humaneval_score:.1f}%\")\n", | |
| "if mbpp_score is not None:\n", | |
| " print(f\"MBPP (subset): {mbpp_score:.1f}%\")\n", | |
| "\n", | |
| "print(\"\\n### Long-Context Benchmarks (RLM) ###\")\n", | |
| "print(\"S-NIAH (needle-in-haystack):\")\n", | |
| "for ctx_len, acc in niah_results.items():\n", | |
| " print(f\" {ctx_len:>6,} chars: {acc:.1f}%\")\n", | |
| "print(f\"Multi-Key Retrieval: {multikey_score:.1f}%\")\n", | |
| "\n", | |
| "print(\"\\n\" + \"=\"*60)\n", | |
| "print(\"Note: These are subset evaluations. For full benchmarks, use:\")\n", | |
| "print(\" - bigcode/bigcode-evaluation-harness for HumanEval/MBPP\")\n", | |
| "print(\" - RULER benchmark for comprehensive long-context eval\")\n", | |
| "print(\"=\"*60)" | |
| ] | |
| }, | |
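| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To fill in the results table in the Summary without copying numbers by hand, the sketch below formats the scores computed above as Markdown rows you can paste in. It assumes the benchmark cells have run, so that `humaneval_score`, `mbpp_score`, `niah_results` and `multikey_score` exist. `results_to_markdown` is a hypothetical helper." | |
| ] | |
| }, |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def results_to_markdown(humaneval, mbpp, niah, multikey) -> str:\n", | |
| "    \"\"\"Format benchmark scores (percentages, or None if skipped) as a Markdown table.\"\"\"\n", | |
| "    def fmt(score):\n", | |
| "        return \"n/a\" if score is None else f\"{score:.1f}%\"\n", | |
| "\n", | |
| "    rows = [\n", | |
| "        \"| Benchmark | Score | Notes |\",\n", | |
| "        \"|-----------|-------|-------|\",\n", | |
| "        f\"| HumanEval (subset) | {fmt(humaneval)} | Code generation |\",\n", | |
| "        f\"| MBPP (subset) | {fmt(mbpp)} | Python problems |\",\n", | |
| "    ]\n", | |
| "    # One row per S-NIAH context length (measured in characters)\n", | |
| "    for ctx_len, acc in niah.items():\n", | |
| "        rows.append(f\"| S-NIAH {ctx_len // 1000}K chars | {fmt(acc)} | Needle in haystack |\")\n", | |
| "    rows.append(f\"| Multi-Key | {fmt(multikey)} | Multiple retrieval |\")\n", | |
| "    return \"\\n\".join(rows)\n", | |
| "\n", | |
| "print(results_to_markdown(humaneval_score, mbpp_score, niah_results, multikey_score))" | |
| ] | |
| }, |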
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "---\n", | |
| "## Summary\n", | |
| "\n", | |
| "This notebook fine-tuned **Qwen3-4B-Instruct-2507** with:\n", | |
| "\n", | |
| "### Training Data:\n", | |
| "| Dataset | Examples | Purpose |\n", | |
| "|---------|----------|----------|\n", | |
| "| Evol-Instruct-Code | ~5,000 | Complex coding instructions |\n", | |
| "| CodeSearchNet | ~5,000 | Function understanding |\n", | |
| "| MBPP | ~400 | Python problems |\n", | |
| "| OpenHermes | ~3,000 | General code tasks |\n", | |
| "| **RLM Trajectories** | **2,000** | REPL interaction patterns |\n", | |
| "\n", | |
| "### Benchmark Results:\n", | |
| "Scores come from the small subsets in Section 10; fill them in after running it.\n", | |
| "\n", | |
| "| Benchmark | Score | Notes |\n", | |
| "|-----------|-------|-------|\n", | |
| "| HumanEval (3 problems) | TBD | Code generation, pass@1 |\n", | |
| "| MBPP (20 problems) | TBD | Python problems |\n", | |
| "| S-NIAH 8K chars | TBD | Needle in haystack |\n", | |
| "| S-NIAH 32K chars | TBD | Long context |\n", | |
| "| Multi-Key (~20K chars) | TBD | Multiple retrieval |\n", | |
| "\n", | |
| "### The model is trained to:\n", | |
| "1. Write Python code for general coding tasks\n", | |
| "2. Use ` ```repl ` blocks to examine context\n", | |
| "3. Call `llm_query()` for recursive sub-calls\n", | |
| "4. Output `FINAL()` / `FINAL_VAR()` properly\n", | |
| "5. Handle contexts far longer than its own context window by keeping them in the REPL rather than in the prompt" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.12.0" | |
| }, | |
| "accelerator": "GPU", | |
| "colab": { | |
| "gpuType": "T4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |