# %% [markdown]
# # Qwen3-1.7B GRPO Training for ARC-AGI Problems
# ## Iterative Masked Reasoning (IMR) with Group Relative Policy Optimization (GRPO)
#
# **Goal:** Train Qwen3-1.7B to solve ARC-AGI visual reasoning puzzles using GRPO
#
# **Approach:** Use structured reasoning with masking and reward-based learning
#
# ### What is GRPO?
# - **Group Relative Policy Optimization (GRPO)** is a reinforcement learning technique
# - It samples a group of responses per prompt, scores them with reward functions, and updates the model toward the higher-scoring responses
# - Well suited to complex reasoning tasks where a reward can check answers that supervised fine-tuning alone struggles to teach
#
# ### What is IMR?
# - **Iterative Masked Reasoning** applies BERT-style masking to reasoning tasks
# - Forces the model to predict missing pieces through structured thinking
# - Teaches "what rule applies here?" instead of just "what word comes next?" (a toy illustration follows below)
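# %% [markdown]
# To make the idea concrete, here is a toy illustration (not part of the
# training pipeline below): one cell of a small grid is replaced by a mask
# marker, and the model's job is to infer the hidden value from the pattern.
# %%
# Toy IMR example: a checkerboard-style grid with one masked cell.
toy_grid = [[7, 9, 7],
            [9, 7, 9],
            [7, 9, 7]]
masked_toy = [row[:] for row in toy_grid]
masked_toy[1][1] = "<MASK>"  # the model should infer 7 from the alternating pattern
print("Masked grid:", masked_toy)
print("Question the model answers: what value belongs at the masked cell?")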
# %% [markdown]
# ## 📦 Import Libraries and Setup Configuration
#
# First, we'll import all necessary libraries and set up our training configuration.
# %%
import json
import torch
import numpy as np
import random
import re
from pathlib import Path
from datasets import Dataset
import pandas as pd
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig, GRPOConfig, GRPOTrainer
from vllm import SamplingParams
from unsloth import vLLMSamplingParams
# Configuration
max_seq_length = 2048 # Reduced from 4096 to save memory
lora_rank = 32
data_dir = Path('../arc-prize-2025-dataset')
print("βœ… Libraries imported successfully!")
print(f"πŸ“ Max sequence length: {max_seq_length}")
print(f"πŸ”§ LoRA rank: {lora_rank}")
# %% [markdown]
# ## 🤖 Load Qwen3-1.7B Model
#
# Now we'll load the Qwen3-1.7B base model and configure it for LoRA fine-tuning.
#
# **Why Qwen3-1.7B?**
# - Has built-in reasoning capabilities (`<think>` mode)
# - Optimized for both thinking and non-thinking modes
# - Good balance between performance and resource usage
#
# **Why LoRA?**
# - **L**ow-**R**ank **A**daptation is memory efficient
# - Only trains a small subset of parameters
# - Faster training while maintaining quality
# %%
print("Loading Qwen3-1.7B model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-1.7B",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2,
    use_gradient_checkpointing = True,
    random_state = 3407,
)
print("✅ Model loaded and configured with LoRA for Qwen3-1.7B!")
# %% [markdown]
# ## 💬 Setup System Prompt for Native Reasoning
#
# We'll use Qwen3-1.7B's native `<think>...</think>` reasoning format, combined with a color legend and a boxed solution format:
#
# 1. **Thinking Section**: `<think>` ... `</think>` (model's native format)
# 2. **Solution Section**: `\boxed{...}` (Python list of lists, no extra text)
# 3. **Color Legend**: Included in the system prompt for reference
# 4. **Example**: Provided in the system prompt to illustrate the expected input/output/answer format
#
# This leverages the model's built-in reasoning capabilities while maintaining a clear solution format for reward functions and evaluation.
# %%
# Use model's native thinking tokens
thinking_start = "<think>"
thinking_end = "</think>"
# ARC color palette mapping (0-9 color codes)
color_legend = '''
ARC Color Palette:
0 = black
1 = blue
2 = red
3 = green
4 = yellow
5 = gray
6 = magenta
7 = orange
8 = cyan
9 = maroon
'''
# Single-shot example used in the system prompt (a simple tiling puzzle)
example_input = """Input:
[[7,9],
[4,3]]"""
example_output = """Output:
[[7,9,7,9,7,9],
[4,3,4,3,4,3],
[9,7,9,7,9,7],
[3,4,3,4,3,4],
[7,9,7,9,7,9],
[4,3,4,3,4,3]]"""
example_boxed = """\\boxed{[[7,9,7,9,7,9],[4,3,4,3,4,3],[9,7,9,7,9,7],[3,4,3,4,3,4],[7,9,7,9,7,9],[4,3,4,3,4,3]]}"""
system_prompt = f"""You are an expert at visual pattern recognition and reasoning.
You will be given ARC-AGI puzzle examples showing input-output grid transformations.
Use the thinking mode to analyze the pattern step by step, then provide your solution grid inside a \\boxed{{...}} block.
If you see any <|reserved_0|> tokens in the grid, your task is to figure out and fill in the correct value for each masked token based on the patterns you observe.
{color_legend}
Format the grid as a Python list of lists, e.g., \\boxed{{[[1,2,3],[4,5,6],[7,8,9]]}}. Do not include any extra text or explanation after the box.
Here is an example:
{example_input}
{example_output}
Answer:
{example_boxed}
"""
print("βœ… System prompt configured with native reasoning, color legend, and example!")
print(f"🧠 Using native thinking tags: {thinking_start} ... {thinking_end}")
print(f"🎯 Solution format: \\boxed{{...}}")
# %% [markdown]
# ## 📊 Load and Explore ARC-AGI Data
#
# The ARC-AGI dataset contains visual reasoning puzzles where:
# - Each puzzle has training examples (input β†’ output transformations)
# - The goal is to learn the pattern and apply it to test cases
#
# **Dataset Structure:**
# - `arc-agi_training_challenges.json`: Contains the puzzle inputs and training examples
# - `arc-agi_training_solutions.json`: Contains the correct outputs for the test cases
# %%
print("Loading ARC-AGI data...")
with open(data_dir / 'arc-agi_training_challenges.json') as f:
    training_challenges = json.load(f)
with open(data_dir / 'arc-agi_training_solutions.json') as f:
    training_solutions = json.load(f)
print(f"✅ Loaded {len(training_challenges)} puzzles")
print(f"📋 Solutions available for {len(training_solutions)} puzzles")
# Let's look at a sample puzzle
sample_id = list(training_challenges.keys())[0]
sample_puzzle = training_challenges[sample_id]
print(f"\nπŸ” Sample puzzle {sample_id}:")
print(f" πŸ“š Training examples: {len(sample_puzzle['train'])}")
print(f" πŸ§ͺ Test cases: {len(sample_puzzle['test'])}")
# %% [markdown]
# ## 🔧 Data Processing Functions
#
# We need helper functions to:
# 1. Convert grids to readable text format
# 2. Format ARC problems for training
# 3. Apply masking for the IMR approach
# %%
def grid_to_string(grid):
    """Return the grid as its Python list-of-lists literal string."""
    return str(grid)

def format_arc_problem(puzzle_id, challenge, solution):
    """Format an ARC problem (training examples + test input) for training"""
    train_examples = challenge['train']
    test_input = challenge['test'][0]['input']
    test_output = solution[0]  # First solution
    # Create training examples description
    examples_text = ""
    for i, example in enumerate(train_examples):
        examples_text += f"Example {i+1}:\n"
        examples_text += f"Input:\n{grid_to_string(example['input'])}\n"
        examples_text += f"Output:\n{grid_to_string(example['output'])}\n\n"
    # Create the test problem section
    problem_text = f"""Here are some example input-output transformations:
{examples_text}Test:
Input:
{grid_to_string(test_input)}
"""
    return problem_text, test_output

def mask_grid_tokens(grid, mask_ratio=0.2):
    """Apply random masking to an output grid for the IMR approach"""
    masked_grid = [row[:] for row in grid]
    mask_positions = []
    for i in range(len(grid)):
        for j in range(len(grid[0])):
            if random.random() < mask_ratio:
                masked_grid[i][j] = -1  # Use -1 for dataset storage
                mask_positions.append((i, j))
    return masked_grid, mask_positions

print("✅ Data processing functions defined!")
# Test the functions
test_problem, test_solution = format_arc_problem(sample_id, sample_puzzle, training_solutions[sample_id])
print(f"\n🧪 Test problem preview:")
print(test_problem)
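# %% [markdown]
# A quick demonstration of `mask_grid_tokens` on a toy grid, so the masking
# behaviour is visible before we build the full dataset. Output varies from run
# to run because the masking is random.
# %%
demo_grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
demo_masked, demo_positions = mask_grid_tokens(demo_grid, mask_ratio=0.3)
print("Original:", demo_grid)
print("Masked:  ", demo_masked)  # masked cells hold -1
print("Masked positions:", demo_positions)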
# %% [markdown]
# ## 📝 Prepare Training Dataset
#
# Now we'll create training examples by:
# 1. Taking the first 100 puzzles (for faster training)
# 2. Formatting each puzzle with examples and test cases
# 3. Creating structured responses with native thinking format
# %%
# Prepare dataset
dataset_entries = []
puzzle_ids = list(training_challenges.keys())[:100] # Use first 100 for faster training
# Define the mask token for text display
MASK_TOKEN = "<|reserved_0|>"
for puzzle_id in puzzle_ids:
    challenge = training_challenges[puzzle_id]
    if puzzle_id in training_solutions:
        solution = training_solutions[puzzle_id]
        try:
            problem_text, test_output = format_arc_problem(puzzle_id, challenge, solution)
            # Sample a random mask ratio for this example
            this_mask_ratio = random.uniform(0, 1)
            masked_grid, _ = mask_grid_tokens(test_output, mask_ratio=this_mask_ratio)
            # Create display version with the reserved token for text
            masked_grid_display = [[MASK_TOKEN if v == -1 else v for v in row] for row in masked_grid]
            masked_grid_str = str(masked_grid_display)
            # Calculate the actual masked ratio
            total_elements = sum(len(row) for row in test_output)
            num_masked = sum(v == -1 for row in masked_grid for v in row)
            actual_masked_ratio = num_masked / total_elements if total_elements > 0 else 0.0
            # Add the masked output to the user prompt
            problem_text_with_mask = problem_text + f"Masked Output:\n{masked_grid_str}\n"
            dataset_entries.append({
                "prompt": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": problem_text_with_mask},
                ],
                "answer": str(test_output),
                "puzzle_id": puzzle_id,
                "masked_ratio": actual_masked_ratio,
                "masked_grid": masked_grid,  # Store -1 values, not string tokens
            })
        except Exception as e:
            print(f"❌ Error processing puzzle {puzzle_id}: {e}")
            continue

print(f"✅ Created {len(dataset_entries)} training examples")
# Create dataset
df = pd.DataFrame(dataset_entries)
dataset = Dataset.from_pandas(df)
print(f"📊 Dataset shape: {len(dataset)} examples")
# %% [markdown]
# ## 🎯 Pre-training for Format Learning (Work in Progress)
#
# This section is a work in progress.
#
# Before GRPO, we do a quick supervised fine-tuning (SFT) to teach the model our format.
#
# **Why pre-training?**
# - GRPO works better when the model already understands the format
# - Speeds up the main GRPO training
# - Ensures consistent output structure
# %%
print("Starting pre-training for format learning...")
pre_train_data = []
for entry in dataset_entries[:20]:  # Use a small subset for pre-training
    reasoning_response = f"{thinking_start}I need to analyze the pattern in the examples and apply it to the test case.{thinking_end}\\boxed{{{entry['answer']}}}"
    messages = entry["prompt"] + [{"role": "assistant", "content": reasoning_response}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
        enable_thinking=True  # Enable native thinking mode
    )
    pre_train_data.append({"text": text})
pre_dataset = Dataset.from_list(pre_train_data)
# Pre-training
trainer = SFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = pre_dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 2,
        warmup_steps = 5,
        num_train_epochs = 1,
        learning_rate = 2e-4,
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",
        output_dir = "pre_training_outputs",
    ),
)
trainer.train()
print("✅ Pre-training completed!")
# Clean up before GRPO
del pre_dataset, trainer
torch.cuda.empty_cache()
# %% [markdown]
# ## 🏆 Setup GRPO Reward Functions
#
# GRPO scores each generated response with reward functions. Here we use a single
# combined reward that covers:
#
# 1. **Format**: Does the response contain `<think>...</think>` and a `\boxed{...}` block?
# 2. **Mask predictions**: How many masked cells does the predicted grid get right?
# 3. **Debugging**: The function also prints responses during training so you can watch progress
#
# **How GRPO works:**
# - The model generates several responses per prompt
# - Each response is scored by the reward function(s)
# - Responses are compared within their group, and the model is updated to favor the higher-scoring ones
# %%
# Regex pattern for extracting the boxed solution, plus a safe literal parser
import ast

box_regex = re.compile(r"\\boxed\{(.+?)\}")

def parse_grid_from_response(response):
    """Extract and parse the grid from \\boxed{...} in a response"""
    match = box_regex.search(response)
    if not match:
        return None
    try:
        # literal_eval safely parses list-of-lists literals from model output
        return ast.literal_eval(match.group(1))
    except (ValueError, SyntaxError):
        return None

def count_correct_mask_predictions(predicted_grid, true_grid, masked_grid):
    """Count how many masked positions were predicted correctly"""
    if not predicted_grid or not isinstance(predicted_grid, list):
        return 0
    correct_count = 0
    for i, row in enumerate(masked_grid):
        for j, cell in enumerate(row):
            if cell == -1:  # -1 marks a masked position
                # Check bounds and correctness
                if (i < len(predicted_grid) and j < len(predicted_grid[i]) and
                        i < len(true_grid) and j < len(true_grid[i]) and
                        predicted_grid[i][j] == true_grid[i][j]):
                    correct_count += 1
    return correct_count

def structure_and_mask_reward(completions, answer, **kwargs):
    """Reward required structure and correct mask predictions.

    TRL passes extra dataset columns (answer, masked_grid, ...) as lists with
    one entry per completion, so we iterate over them together.
    """
    masked_grids = kwargs.get('masked_grid', [None] * len(completions))
    scores = []
    for completion, ans, masked_grid in zip(completions, answer, masked_grids):
        score = 0
        response = completion[0]["content"]
        true_grid = ast.literal_eval(ans) if isinstance(ans, str) else ans
        # Debug print: show the model's response and expected answer during training
        print("\n[DEBUG] Model response during training:")
        print(response)
        print("[DEBUG] Expected answer:")
        print(ans)
        print("-" * 60)
        # Structure penalties (-10 each)
        if "<think>" not in response or "</think>" not in response:
            score -= 10
        if "\\boxed{" not in response:
            score -= 10
        # Mask prediction rewards (+1 per correctly filled masked cell)
        if masked_grid:
            predicted_grid = parse_grid_from_response(response)
            score += count_correct_mask_predictions(predicted_grid, true_grid, masked_grid)
        scores.append(score)
    return scores
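# %% [markdown]
# Before wiring the reward function into the trainer, a small smoke test on
# hand-written completions helps confirm the scoring logic. This mimics TRL's
# conversational completion format (each completion is a list of messages);
# the grids and answers below are made up purely for the test.
# %%
# One well-formed and one malformed completion for a 2x2 grid whose (0, 0) cell was masked.
_fake_completions = [
    [{"role": "assistant", "content": "<think>checkerboard</think>\\boxed{[[1,2],[3,4]]}"}],
    [{"role": "assistant", "content": "no thinking, no box"}],
]
_fake_answers = ["[[1,2],[3,4]]", "[[1,2],[3,4]]"]
_fake_masked = [[[-1, 2], [3, 4]], [[-1, 2], [3, 4]]]
print(structure_and_mask_reward(_fake_completions, _fake_answers, masked_grid=_fake_masked))
# Expected: [1, -20] (one correct mask prediction vs. two structure penalties)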
# %% [markdown]
# ## 🎯 GRPO Training Configuration
#
# We'll configure GRPO training with:
# 1. **Temperature**: Controls randomness of responses
# 2. **Learning Rate**: How quickly the model learns
# 3. **Weight Decay**: Prevents overfitting
# 4. **Warmup Ratio**: Gradually increases learning rate
# %%
# GRPO training configuration
max_prompt_length = 256  # Reduced from 512 to save memory
max_completion_length = max_seq_length - max_prompt_length
# vLLM sampling parameters passed through generation_kwargs
vllm_generation_kwargs = {
    "min_p": 0.0,
    "seed": 3407,
    "temperature": 0.6,
    "top_p": 0.95,
    "top_k": 20,
    "stop": [tokenizer.eos_token],
    "include_stop_str_in_output": True,
}
training_args = GRPOConfig(
    # Output and logging
    output_dir="arc_grpo_outputs",
    logging_steps=1,
    save_steps=50,
    report_to="none",
    # Training hyperparameters
    learning_rate=3e-6,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    max_steps=200,
    # Generation settings
    temperature=0.6,
    top_p=0.95,
    top_k=20,
    num_generations=2,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    generation_kwargs=vllm_generation_kwargs,  # Pass vLLM sampling params here
    # vLLM settings
    use_vllm=True,
    vllm_mode="colocate",  # Use in-process vLLM for training
    vllm_gpu_memory_utilization=0.6,  # Adjust as needed
    # float8_kv_cache is set at model load time if supported by your model loader
    # Optionally, for regex-guided decoding:
    # vllm_guided_decoding_regex=r"<think>(.*?)</think>\\boxed\{(.+?)\}",
    # Additional config options can be set as needed
)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=structure_and_mask_reward,
    args=training_args,
    train_dataset=dataset,
)
print("✅ GRPO trainer configured!")
# %% [markdown]
# ## 🎓 Run GRPO Training
#
# This is the main training loop. The model will:
# 1. Generate multiple responses for each puzzle
# 2. Get scored by our reward functions
# 3. Learn to maximize total reward
#
# **What to expect:**
# - Initial rewards will be low/negative
# - Format compliance should improve first
# - Solution accuracy should improve gradually
# - Training may take 30-60 minutes depending on hardware
# %%
print("πŸš€ Starting GRPO training on ARC-AGI problems...")
print("⏳ This may take a while. Watch for reward improvements!")
trainer.train()
# %% [markdown]
# ## 💾 Save the Trained Model
#
# Save our trained LoRA adapter so we can use it later.
# %%
print("πŸ’Ύ Saving trained model...")
model.save_lora("arc_grpo_lora")
print("βœ… Model saved as 'arc_grpo_lora'")
# %% [markdown]
# ## 🧪 Test the Trained Model
#
# Let's test our trained model on a puzzle to see how it performs!
# %%
print("πŸ§ͺ Testing the trained model...")
test_puzzle_id = puzzle_ids[0]
test_challenge = training_challenges[test_puzzle_id]
test_solution = training_solutions[test_puzzle_id]
problem_text, expected_output = format_arc_problem(test_puzzle_id, test_challenge, test_solution)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": problem_text},
]
text = tokenizer.apply_chat_template(
messages,
add_generation_prompt = True,
tokenize = False,
enable_thinking = True, # Enable native thinking mode
)
sampling_params = SamplingParams(
temperature = 0.6, # Recommended for thinking mode
top_p = 0.95,
top_k = 20,
max_tokens = 2048,
)
print(f"\nπŸ” Testing on puzzle: {test_puzzle_id}")
print(f"🎯 Expected output: {expected_output}")
print("\nπŸ€– Model response:")
output = model.fast_generate(
text,
sampling_params = sampling_params,
lora_request = model.load_lora("arc_grpo_lora"),
)[0].outputs[0].text
print(output)
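# %% [markdown]
# Optional: reuse `parse_grid_from_response` to pull the predicted grid out of
# the generation and compare it to the expected output. This is an exact-match
# check only; partial credit would need a cell-wise comparison.
# %%
predicted = parse_grid_from_response(output)
if predicted is None:
    print("⚠️ Could not find a \\boxed{...} grid in the model output.")
else:
    print("✅ Exact match!" if predicted == expected_output else "❌ Prediction differs from the expected output.")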
# %% [markdown]
# ## 🎓 Training Complete!
#
# Congratulations! You've successfully trained a Qwen3-1.7B model on ARC-AGI problems using GRPO.
#
# ### What we accomplished:
# - ✅ Loaded and configured Qwen3-1.7B with LoRA
# - ✅ Used the native `<think>...</think>` reasoning format
# - ✅ Processed ARC-AGI visual reasoning data
# - ✅ Pre-trained for format compliance
# - ✅ Trained with GRPO using a structure-and-mask reward function
# - ✅ Tested the final model
#
# ### Next steps:
# - 🔬 Try the model on more test cases
# - 📊 Evaluate performance on a validation set
# - 🔧 Experiment with different reward functions
# - 📈 Train for more steps for better performance
#
# ### Files created:
# - `arc_grpo_lora/`: Your trained LoRA adapter
# - `arc_grpo_outputs/`: Training checkpoints and logs
# %%
print("πŸŽ‰ Training completed! Model saved as 'arc_grpo_lora'")
print("\nπŸ“‹ Summary:")
print(f" οΏ½οΏ½ Puzzles processed: {len(dataset_entries)}")
print(f" πŸ“Š Examples used: {len(dataset)}")
print(f" πŸŽ“ Training steps: 200")
print(f" πŸ’Ύ Model saved: arc_grpo_lora/")
print("\nπŸš€ Ready to solve ARC-AGI puzzles!")