@rileyseaburg
Last active January 18, 2026 17:51

Concept-First Code Generation: JEPA-style concept prediction for code (VL-JEPA inspired)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Concept-First Code Generation\n",
"\n",
"**Inspired by VL-JEPA**: Predict concept embeddings first, then generate code conditioned on them.\n",
"\n",
"## The Idea\n",
"\n",
"Traditional autoregressive models predict tokens one at a time, which can lead to:\n",
"- Losing coherence over long generations\n",
"- Hallucinating APIs\n",
"- Repetition loops\n",
"\n",
"**Concept-First** approach:\n",
"1. **Concept Encoder**: Encode code snippets into semantic embeddings\n",
"2. **Concept Predictor**: Given a query, predict what the code embedding should look like\n",
"3. **Concept-Conditioned Generation**: Generate code guided by the predicted concept\n",
"\n",
"```\n",
"Query: \"Write fibonacci\" \n",
" ↓\n",
"Concept Predictor → [0.23, -0.87, ...] (embedding)\n",
" ↓\n",
"Retrieve similar: [\"def fib(n): ...\", \"def factorial(n): ...\"]\n",
" ↓\n",
"Conditioned Generation → \"def fibonacci(n):\\n if n <= 1: ...\"\n",
"```\n",
"\n",
"## Models Used (January 2026 - Latest)\n",
"\n",
"| Component | Model | Why |\n",
"|-----------|-------|-----|\n",
"| **Code Encoder** | `Salesforce/SFR-Embedding-Code-2B_R` | SOTA code embeddings (CoIR: 67.4), 2B params |\n",
"| **Text Encoder** | `Alibaba-NLP/gte-Qwen2-1.5B-instruct` | Latest GTE with instruction support |\n",
"| **Code LLM** | `Qwen/Qwen3-Coder-30B-A3B-Instruct` | Latest Qwen3 Coder MoE (Jan 2026) |\n",
"| **Dataset** | `bigcode/the-stack-v2` + `MBPP` + `HumanEval` + `Evol-Instruct` | Diverse, high-quality code |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies (latest versions as of Jan 2026)\n",
"!pip install -q --upgrade transformers>=4.57.0 datasets>=3.0.0 torch>=2.5.0 \n",
"!pip install -q --upgrade sentence-transformers>=3.3.0 accelerate>=1.2.0 bitsandbytes>=0.45.0\n",
"!pip install -q --upgrade huggingface_hub einops"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig\n",
"from sentence_transformers import SentenceTransformer\n",
"from datasets import load_dataset, concatenate_datasets\n",
"import numpy as np\n",
"from typing import List, Dict, Tuple, Optional\n",
"import json\n",
"from tqdm.auto import tqdm\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"# Check GPU\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
" print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
" \n",
"# Print versions\n",
"import transformers, datasets, sentence_transformers\n",
"print(f\"\\nLibrary versions (Jan 2026):\")\n",
"print(f\" transformers: {transformers.__version__}\")\n",
"print(f\" datasets: {datasets.__version__}\")\n",
"print(f\" sentence-transformers: {sentence_transformers.__version__}\")\n",
"print(f\" torch: {torch.__version__}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 1: Concept Encoder (SFR-Embedding-Code-2B)\n",
"\n",
"We use **Salesforce SFR-Embedding-Code-2B** - the current SOTA for code embeddings.\n",
"- CoIR benchmark: 67.4 NDCG@10 (best in class)\n",
"- Supports code-to-code and text-to-code retrieval\n",
"- 2B parameters, 32K context length"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConceptEncoder:\n",
" \"\"\"\n",
" Encodes code snippets into semantic concept embeddings.\n",
" Uses SFR-Embedding-Code-2B (SOTA for code embeddings, Jan 2026).\n",
" \"\"\"\n",
" \n",
" def __init__(self, model_name: str = \"Salesforce/SFR-Embedding-Code-2B_R\"):\n",
" \"\"\"\n",
" Initialize with Salesforce SFR-Embedding-Code.\n",
" \n",
" Alternatives (2026):\n",
" - \"Salesforce/SFR-Embedding-Code-2B_R\" (2B, SOTA CoIR 67.4)\n",
" - \"Salesforce/SFR-Embedding-Code-400M_R\" (400M, faster)\n",
" - \"jinaai/jina-embeddings-v3\" (if available)\n",
" \"\"\"\n",
" print(f\"Loading concept encoder: {model_name}\")\n",
" self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)\n",
" self.model.to(device)\n",
" self.model.eval()\n",
" \n",
" # SFR-Embedding-Code uses 2048-dim embeddings\n",
" self.embed_dim = 2048\n",
" self.max_length = 8192 # Can go up to 32K but 8K is efficient\n",
" \n",
" print(f\"Embedding dimension: {self.embed_dim}\")\n",
" print(f\"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}\")\n",
" \n",
" @torch.no_grad()\n",
" def encode(self, code: str) -> torch.Tensor:\n",
" \"\"\"Encode a single code snippet.\"\"\"\n",
" embeddings = self.model.encode_corpus([code], max_length=self.max_length)\n",
" embeddings = F.normalize(embeddings, p=2, dim=-1)\n",
" return embeddings[0]\n",
" \n",
" @torch.no_grad()\n",
" def encode_batch(self, codes: List[str], batch_size: int = 16) -> torch.Tensor:\n",
" \"\"\"Encode multiple code snippets.\"\"\"\n",
" all_embeddings = []\n",
" \n",
" for i in tqdm(range(0, len(codes), batch_size), desc=\"Encoding\", leave=False):\n",
" batch = codes[i:i+batch_size]\n",
" embeddings = self.model.encode_corpus(batch, max_length=self.max_length)\n",
" embeddings = F.normalize(embeddings, p=2, dim=-1)\n",
" all_embeddings.append(embeddings.cpu())\n",
" \n",
" return torch.cat(all_embeddings, dim=0)\n",
" \n",
" @torch.no_grad()\n",
" def encode_query(self, query: str, instruction: str = \"Given Code or Text, retrieve relevant code\") -> torch.Tensor:\n",
" \"\"\"Encode a text query with instruction (for text-to-code retrieval).\"\"\"\n",
" embeddings = self.model.encode_queries([query], instruction=instruction, max_length=self.max_length)\n",
" embeddings = F.normalize(embeddings, p=2, dim=-1)\n",
" return embeddings[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize concept encoder\n",
"# Use 400M version for faster iteration, 2B for best quality\n",
"concept_encoder = ConceptEncoder(model_name=\"Salesforce/SFR-Embedding-Code-400M_R\")\n",
"\n",
"# Test it with similar code patterns\n",
"test_codes = [\n",
" \"def fibonacci(n):\\n if n <= 1:\\n return n\\n return fibonacci(n-1) + fibonacci(n-2)\",\n",
" \"def factorial(n):\\n if n <= 1:\\n return 1\\n return n * factorial(n-1)\",\n",
" \"def bubble_sort(arr):\\n for i in range(len(arr)):\\n for j in range(len(arr)-1):\\n if arr[j] > arr[j+1]:\\n arr[j], arr[j+1] = arr[j+1], arr[j]\",\n",
" \"def binary_search(arr, x):\\n low, high = 0, len(arr)-1\\n while low <= high:\\n mid = (low + high) // 2\\n if arr[mid] == x:\\n return mid\"\n",
"]\n",
"\n",
"embeddings = concept_encoder.encode_batch(test_codes)\n",
"print(f\"Embeddings shape: {embeddings.shape}\")\n",
"\n",
"# Compute similarity matrix\n",
"similarity = embeddings @ embeddings.T\n",
"print(\"\\nSimilarity matrix (recursive functions should cluster together):\")\n",
"labels = [\"fibonacci\", \"factorial\", \"bubble_sort\", \"binary_search\"]\n",
"print(f\"{'':15}\", end=\"\")\n",
"for l in labels:\n",
" print(f\"{l:15}\", end=\"\")\n",
"print()\n",
"for i, l in enumerate(labels):\n",
" print(f\"{l:15}\", end=\"\")\n",
" for j in range(len(labels)):\n",
" print(f\"{similarity[i,j]:.3f} \", end=\"\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 2: Build Concept Bank from Multiple Datasets (2026)\n",
"\n",
"We combine multiple high-quality code datasets:\n",
"1. **MBPP** - Curated Python problems with descriptions\n",
"2. **HumanEval** - OpenAI's code benchmark\n",
"3. **CodeSearchNet Python** - Large-scale code with docstrings\n",
"4. **Evol-Instruct-Code** - High quality instruction-following code\n",
"5. **The Stack v2** (sample) - Latest open-source code collection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConceptBank:\n",
" \"\"\"\n",
" A searchable bank of code concepts.\n",
" Maps embeddings to code snippets for retrieval.\n",
" \"\"\"\n",
" \n",
" def __init__(self, encoder: ConceptEncoder):\n",
" self.encoder = encoder\n",
" self.embeddings = None # (N, embed_dim)\n",
" self.codes = [] # List of code strings\n",
" self.descriptions = [] # List of descriptions/docstrings\n",
" self.sources = [] # Track where each example came from\n",
" \n",
" def add(self, codes: List[str], descriptions: List[str] = None, source: str = \"unknown\"):\n",
" \"\"\"Add code snippets to the bank.\"\"\"\n",
" if descriptions is None:\n",
" descriptions = [\"\"] * len(codes)\n",
" \n",
" # Filter out empty/invalid codes\n",
" valid_pairs = [(c, d) for c, d in zip(codes, descriptions) if c and len(c.strip()) > 10]\n",
" if not valid_pairs:\n",
" print(f\" No valid codes from {source}\")\n",
" return\n",
" \n",
" codes, descriptions = zip(*valid_pairs)\n",
" codes, descriptions = list(codes), list(descriptions)\n",
" \n",
" print(f\" Encoding {len(codes)} examples from {source}...\")\n",
" new_embeddings = self.encoder.encode_batch(codes)\n",
" \n",
" if self.embeddings is None:\n",
" self.embeddings = new_embeddings\n",
" else:\n",
" self.embeddings = torch.cat([self.embeddings, new_embeddings], dim=0)\n",
" \n",
" self.codes.extend(codes)\n",
" self.descriptions.extend(descriptions)\n",
" self.sources.extend([source] * len(codes))\n",
" print(f\" Bank size: {len(self.codes)} concepts\")\n",
" \n",
" def search(self, query_embedding: torch.Tensor, k: int = 5) -> List[Dict]:\n",
" \"\"\"Find k nearest concepts to the query embedding.\"\"\"\n",
" query_embedding = query_embedding.cpu()\n",
" if query_embedding.dim() == 1:\n",
" query_embedding = query_embedding.unsqueeze(0)\n",
" \n",
" similarities = (query_embedding @ self.embeddings.T).squeeze(0)\n",
" top_k = similarities.topk(min(k, len(self.codes)))\n",
" \n",
" results = []\n",
" for idx, score in zip(top_k.indices.tolist(), top_k.values.tolist()):\n",
" results.append({\n",
" \"code\": self.codes[idx],\n",
" \"description\": self.descriptions[idx],\n",
" \"similarity\": score,\n",
" \"source\": self.sources[idx]\n",
" })\n",
" return results\n",
" \n",
" def search_by_code(self, code: str, k: int = 5) -> List[Dict]:\n",
" \"\"\"Find similar code snippets.\"\"\"\n",
" embedding = self.encoder.encode(code)\n",
" return self.search(embedding, k)\n",
" \n",
" def search_by_text(self, text: str, k: int = 5) -> List[Dict]:\n",
" \"\"\"Find code matching a text description (uses query encoder with instruction).\"\"\"\n",
" embedding = self.encoder.encode_query(text)\n",
" return self.search(embedding, k)\n",
" \n",
" def stats(self):\n",
" \"\"\"Print statistics about the concept bank.\"\"\"\n",
" from collections import Counter\n",
" source_counts = Counter(self.sources)\n",
" print(f\"\\nConcept Bank Statistics:\")\n",
" print(f\" Total concepts: {len(self.codes)}\")\n",
" print(f\" Embedding dim: {self.embeddings.shape[1]}\")\n",
" print(f\" Sources:\")\n",
" for source, count in source_counts.most_common():\n",
" print(f\" - {source}: {count}\")"
]
},
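{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bank above does exact brute-force search with a single matmul, which is fine at this scale. For much larger banks, an approximate-nearest-neighbor library such as FAISS is the usual swap-in. The sketch below is an optional addition, not part of the original pipeline: it assumes `faiss-cpu` is installed and introduces a hypothetical `build_faiss_index` helper. Because the bank embeddings are L2-normalized, an exact inner-product index should reproduce `ConceptBank.search`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional FAISS index for larger concept banks (a sketch, assuming\n",
"# `pip install faiss-cpu`; `build_faiss_index` is a hypothetical helper).\n",
"try:\n",
"    import faiss\n",
"\n",
"    def build_faiss_index(bank: ConceptBank) -> \"faiss.IndexFlatIP\":\n",
"        # Inner product on L2-normalized vectors equals cosine similarity\n",
"        emb = bank.embeddings.numpy().astype(\"float32\")\n",
"        index = faiss.IndexFlatIP(emb.shape[1])\n",
"        index.add(emb)  # exact search; swap for an IVF index to go approximate\n",
"        return index\n",
"except ImportError:\n",
"    print(\"faiss not installed; ConceptBank.search (brute-force matmul) still works\")"
]
},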
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load multiple datasets (2026 versions)\n",
"print(\"Loading code datasets (Jan 2026)...\")\n",
"print(\"=\" * 50)\n",
"\n",
"all_codes = []\n",
"all_descriptions = []\n",
"all_sources = []\n",
"\n",
"# 1. MBPP - Curated Python problems (google-research version)\n",
"print(\"\\n1. Loading MBPP...\")\n",
"try:\n",
" mbpp = load_dataset(\"google-research-datasets/mbpp\", \"full\", split=\"train\", trust_remote_code=True)\n",
" for ex in mbpp:\n",
" all_codes.append(ex[\"code\"])\n",
" all_descriptions.append(ex[\"text\"])\n",
" all_sources.append(\"mbpp\")\n",
" print(f\" Added {len(mbpp)} examples\")\n",
"except Exception as e:\n",
" print(f\" Failed: {e}\")\n",
"\n",
"# 2. HumanEval - OpenAI benchmark (latest version)\n",
"print(\"\\n2. Loading HumanEval...\")\n",
"try:\n",
" humaneval = load_dataset(\"openai/openai_humaneval\", split=\"test\", trust_remote_code=True)\n",
" for ex in humaneval:\n",
" code = ex[\"prompt\"] + ex[\"canonical_solution\"]\n",
" all_codes.append(code)\n",
" desc = ex[\"prompt\"].split('\"\"\"')[1] if '\"\"\"' in ex[\"prompt\"] else ex[\"entry_point\"]\n",
" all_descriptions.append(desc.strip())\n",
" all_sources.append(\"humaneval\")\n",
" print(f\" Added {len(humaneval)} examples\")\n",
"except Exception as e:\n",
" print(f\" Failed: {e}\")\n",
"\n",
"# 3. CodeSearchNet Python - Large scale with docstrings\n",
"print(\"\\n3. Loading CodeSearchNet (Python subset)...\")\n",
"try:\n",
" csn = load_dataset(\"code-search-net/code_search_net\", \"python\", split=\"train\", trust_remote_code=True)\n",
" csn_sample = csn.shuffle(seed=42).select(range(min(5000, len(csn))))\n",
" for ex in csn_sample:\n",
" if ex[\"func_code_string\"] and ex[\"func_documentation_string\"]:\n",
" all_codes.append(ex[\"func_code_string\"])\n",
" all_descriptions.append(ex[\"func_documentation_string\"])\n",
" all_sources.append(\"codesearchnet\")\n",
" print(f\" Added {len(csn_sample)} examples\")\n",
"except Exception as e:\n",
" print(f\" Failed: {e}\")\n",
"\n",
"# 4. Evol-Instruct-Code - High quality instruction-following code\n",
"print(\"\\n4. Loading Evol-Instruct-Code...\")\n",
"try:\n",
" evol = load_dataset(\"nickrosh/Evol-Instruct-Code-80k-v1\", split=\"train\", trust_remote_code=True)\n",
" evol_sample = evol.shuffle(seed=42).select(range(min(3000, len(evol))))\n",
" count = 0\n",
" for ex in evol_sample:\n",
" output = ex[\"output\"]\n",
" if \"```python\" in output:\n",
" code = output.split(\"```python\")[1].split(\"```\")[0].strip()\n",
" if len(code) > 20:\n",
" all_codes.append(code)\n",
" all_descriptions.append(ex[\"instruction\"][:200])\n",
" all_sources.append(\"evol-instruct\")\n",
" count += 1\n",
" print(f\" Added {count} examples from Evol-Instruct\")\n",
"except Exception as e:\n",
" print(f\" Failed: {e}\")\n",
"\n",
"# 5. Magicoder-OSS-Instruct - High quality code instructions (2024+)\n",
"print(\"\\n5. Loading Magicoder-OSS-Instruct...\")\n",
"try:\n",
" magic = load_dataset(\"ise-uiuc/Magicoder-OSS-Instruct-75K\", split=\"train\", trust_remote_code=True)\n",
" magic_sample = magic.shuffle(seed=42).select(range(min(2000, len(magic))))\n",
" count = 0\n",
" for ex in magic_sample:\n",
" if \"```python\" in ex[\"solution\"]:\n",
" code = ex[\"solution\"].split(\"```python\")[1].split(\"```\")[0].strip()\n",
" if len(code) > 20:\n",
" all_codes.append(code)\n",
" all_descriptions.append(ex[\"problem\"][:200])\n",
" all_sources.append(\"magicoder\")\n",
" count += 1\n",
" print(f\" Added {count} examples from Magicoder\")\n",
"except Exception as e:\n",
" print(f\" Failed: {e}\")\n",
"\n",
"print(f\"\\n{'='*50}\")\n",
"print(f\"Total collected: {len(all_codes)} code examples\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build concept bank\n",
"print(\"\\nBuilding concept bank...\")\n",
"concept_bank = ConceptBank(concept_encoder)\n",
"\n",
"# Add in batches by source for better tracking\n",
"from collections import defaultdict\n",
"by_source = defaultdict(lambda: {\"codes\": [], \"descs\": []})\n",
"for code, desc, source in zip(all_codes, all_descriptions, all_sources):\n",
" by_source[source][\"codes\"].append(code)\n",
" by_source[source][\"descs\"].append(desc)\n",
"\n",
"for source, data in by_source.items():\n",
" concept_bank.add(data[\"codes\"], data[\"descs\"], source=source)\n",
"\n",
"concept_bank.stats()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test retrieval with text-to-code\n",
"print(\"=\" * 60)\n",
"print(\"Testing concept retrieval (text-to-code)\")\n",
"print(\"=\" * 60)\n",
"\n",
"test_queries = [\n",
" \"fibonacci sequence recursive implementation\",\n",
" \"sort a list using quicksort\",\n",
" \"check if string is palindrome\",\n",
" \"binary tree traversal inorder\",\n",
" \"read and parse json file\"\n",
"]\n",
"\n",
"for query in test_queries:\n",
" print(f\"\\nQuery: '{query}'\")\n",
" print(\"-\" * 40)\n",
" results = concept_bank.search_by_text(query, k=2)\n",
" for i, r in enumerate(results):\n",
" desc_preview = r['description'][:60].replace('\\n', ' ') if r['description'] else \"(no description)\"\n",
" print(f\" [{i+1}] (sim={r['similarity']:.3f}, src={r['source']}) {desc_preview}...\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 3: Concept Predictor (JEPA-style)\n",
"\n",
"Using **GTE-Qwen2-1.5B-instruct** - latest GTE model with instruction following.\n",
"This encodes natural language queries and predicts the concept embedding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConceptPredictor(nn.Module):\n",
" \"\"\"\n",
" JEPA-style concept predictor.\n",
" Given a text query, predicts the concept embedding of the target code.\n",
" \n",
" Architecture:\n",
" - Text encoder (frozen): GTE-Qwen2 (latest MTEB leader)\n",
" - Projection head (trainable): Map text embedding to code concept space\n",
" \"\"\"\n",
" \n",
" def __init__(\n",
" self, \n",
" text_encoder_name: str = \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\",\n",
" concept_dim: int = 2048,\n",
" hidden_dim: int = 2048\n",
" ):\n",
" \"\"\"\n",
" Initialize with state-of-the-art text encoder (Jan 2026).\n",
" \n",
" Alternatives:\n",
" - \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\" (1.5B, best quality)\n",
" - \"Alibaba-NLP/gte-large-en-v1.5\" (434M, faster)\n",
" - \"BAAI/bge-m3\" (multilingual)\n",
" \"\"\"\n",
" super().__init__()\n",
" \n",
" # Text encoder (frozen) - use sentence-transformers for easy loading\n",
" print(f\"Loading text encoder: {text_encoder_name}\")\n",
" try:\n",
" self.text_encoder = SentenceTransformer(text_encoder_name, trust_remote_code=True)\n",
" except:\n",
" # Fallback to a reliable model\n",
" print(\" Falling back to gte-large-en-v1.5\")\n",
" text_encoder_name = \"Alibaba-NLP/gte-large-en-v1.5\"\n",
" self.text_encoder = SentenceTransformer(text_encoder_name, trust_remote_code=True)\n",
" \n",
" self.text_encoder.to(device)\n",
" text_dim = self.text_encoder.get_sentence_embedding_dimension()\n",
" print(f\"Text embedding dim: {text_dim}\")\n",
" \n",
" # Freeze text encoder\n",
" for param in self.text_encoder.parameters():\n",
" param.requires_grad = False\n",
" \n",
" # Projection head (trainable) - maps text space to code concept space\n",
" self.projector = nn.Sequential(\n",
" nn.Linear(text_dim, hidden_dim),\n",
" nn.GELU(),\n",
" nn.LayerNorm(hidden_dim),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_dim, hidden_dim),\n",
" nn.GELU(),\n",
" nn.LayerNorm(hidden_dim),\n",
" nn.Dropout(0.1),\n",
" nn.Linear(hidden_dim, concept_dim)\n",
" ).to(device)\n",
" \n",
" self.concept_dim = concept_dim\n",
" self.text_dim = text_dim\n",
" \n",
" # Count parameters\n",
" trainable = sum(p.numel() for p in self.projector.parameters())\n",
" print(f\"Trainable parameters: {trainable:,}\")\n",
" \n",
" def encode_text(self, texts: List[str]) -> torch.Tensor:\n",
" \"\"\"Encode text queries using frozen encoder.\"\"\"\n",
" with torch.no_grad():\n",
" embeddings = self.text_encoder.encode(\n",
" texts, \n",
" convert_to_tensor=True,\n",
" device=device,\n",
" show_progress_bar=False\n",
" )\n",
" return embeddings\n",
" \n",
" def forward(self, texts: List[str]) -> torch.Tensor:\n",
" \"\"\"Predict concept embeddings from text queries.\"\"\"\n",
" text_embeddings = self.encode_text(texts)\n",
" concept_embeddings = self.projector(text_embeddings)\n",
" return F.normalize(concept_embeddings, p=2, dim=-1)\n",
" \n",
" def predict(self, query: str) -> torch.Tensor:\n",
" \"\"\"Predict concept embedding for a single query.\"\"\"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" return self.forward([query])[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize concept predictor\n",
"concept_predictor = ConceptPredictor(concept_dim=concept_encoder.embed_dim)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 4: Training the Concept Predictor\n",
"\n",
"We train with **InfoNCE loss** (like VL-JEPA and CLIP):\n",
"- Positive: (query, correct_code_embedding)\n",
"- Negatives: (query, other_code_embeddings in batch)\n",
"\n",
"The predictor learns to map natural language queries to the code concept space."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConceptDataset(Dataset):\n",
" \"\"\"Dataset of (description, code) pairs.\"\"\"\n",
" \n",
" def __init__(self, descriptions: List[str], codes: List[str]):\n",
" # Filter pairs where both exist\n",
" self.pairs = [\n",
" (d, c) for d, c in zip(descriptions, codes) \n",
" if d and c and len(d.strip()) > 5 and len(c.strip()) > 10\n",
" ]\n",
" print(f\"Dataset: {len(self.pairs)} valid pairs\")\n",
" \n",
" def __len__(self):\n",
" return len(self.pairs)\n",
" \n",
" def __getitem__(self, idx):\n",
" desc, code = self.pairs[idx]\n",
" return {\"description\": desc, \"code\": code}\n",
"\n",
"\n",
"def info_nce_loss(\n",
" predicted: torch.Tensor,\n",
" target: torch.Tensor,\n",
" temperature: float = 0.07\n",
") -> torch.Tensor:\n",
" \"\"\"\n",
" InfoNCE contrastive loss (same as CLIP/VL-JEPA).\n",
" \"\"\"\n",
" predicted = F.normalize(predicted, p=2, dim=-1)\n",
" target = F.normalize(target, p=2, dim=-1)\n",
" \n",
" logits = (predicted @ target.T) / temperature\n",
" labels = torch.arange(len(predicted), device=predicted.device)\n",
" \n",
" loss_p2t = F.cross_entropy(logits, labels)\n",
" loss_t2p = F.cross_entropy(logits.T, labels)\n",
" \n",
" return (loss_p2t + loss_t2p) / 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_concept_predictor(\n",
" predictor: ConceptPredictor,\n",
" encoder: ConceptEncoder,\n",
" descriptions: List[str],\n",
" codes: List[str],\n",
" epochs: int = 15,\n",
" batch_size: int = 32,\n",
" lr: float = 2e-4,\n",
" warmup_ratio: float = 0.1\n",
"):\n",
" \"\"\"\n",
" Train the concept predictor with InfoNCE loss.\n",
" \"\"\"\n",
" dataset = ConceptDataset(descriptions, codes)\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)\n",
" \n",
" optimizer = torch.optim.AdamW(predictor.projector.parameters(), lr=lr, weight_decay=0.01)\n",
" \n",
" total_steps = epochs * len(dataloader)\n",
" warmup_steps = int(total_steps * warmup_ratio)\n",
" \n",
" def lr_lambda(step):\n",
" if step < warmup_steps:\n",
" return step / warmup_steps\n",
" return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))\n",
" \n",
" scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n",
" \n",
" predictor.train()\n",
" best_loss = float('inf')\n",
" \n",
" for epoch in range(epochs):\n",
" total_loss = 0\n",
" num_batches = 0\n",
" \n",
" pbar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{epochs}\")\n",
" for batch in pbar:\n",
" with torch.no_grad():\n",
" target_embeddings = encoder.encode_batch(batch[\"code\"]).to(device)\n",
" \n",
" predicted_embeddings = predictor(batch[\"description\"])\n",
" loss = info_nce_loss(predicted_embeddings, target_embeddings)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(predictor.projector.parameters(), 1.0)\n",
" optimizer.step()\n",
" scheduler.step()\n",
" \n",
" total_loss += loss.item()\n",
" num_batches += 1\n",
" pbar.set_postfix({\"loss\": f\"{loss.item():.4f}\", \"lr\": f\"{scheduler.get_last_lr()[0]:.2e}\"})\n",
" \n",
" avg_loss = total_loss / num_batches\n",
" print(f\"Epoch {epoch+1}: avg_loss = {avg_loss:.4f}\")\n",
" \n",
" if avg_loss < best_loss:\n",
" best_loss = avg_loss\n",
" \n",
" predictor.eval()\n",
" print(f\"\\nTraining complete! Best loss: {best_loss:.4f}\")\n",
" return predictor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train the concept predictor\n",
"print(\"Training concept predictor...\")\n",
"print(\"=\" * 50)\n",
"\n",
"concept_predictor = train_concept_predictor(\n",
" concept_predictor,\n",
" concept_encoder,\n",
" all_descriptions,\n",
" all_codes,\n",
" epochs=15,\n",
" batch_size=32\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the trained predictor\n",
"print(\"=\" * 60)\n",
"print(\"Testing trained concept predictor\")\n",
"print(\"=\" * 60)\n",
"\n",
"test_queries = [\n",
" \"write a function to compute fibonacci numbers\",\n",
" \"implement binary search algorithm\",\n",
" \"check if a string is a valid palindrome\",\n",
" \"find the maximum element in a list\",\n",
" \"parse a JSON file and extract data\",\n",
" \"implement an LRU cache\"\n",
"]\n",
"\n",
"for query in test_queries:\n",
" print(f\"\\nQuery: '{query}'\")\n",
" print(\"-\" * 40)\n",
" \n",
" predicted_embedding = concept_predictor.predict(query)\n",
" results = concept_bank.search(predicted_embedding, k=2)\n",
" \n",
" for i, r in enumerate(results):\n",
" desc_preview = r['description'][:50].replace('\\n', ' ') if r['description'] else \"(no desc)\"\n",
" print(f\" [{i+1}] (sim={r['similarity']:.3f}, src={r['source']}) {desc_preview}...\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 5: Concept-Conditioned Code Generation\n",
"\n",
"Using **Qwen3-Coder-30B-A3B-Instruct** - latest Qwen3 Coder MoE model (Dec 2025):\n",
"- 30B total params, 3B active (MoE)\n",
"- Best open-source code model as of Jan 2026\n",
"- Excellent instruction following"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConceptFirstCodeGenerator:\n",
" \"\"\"\n",
" Concept-First Code Generation Pipeline.\n",
" \n",
" 1. Predict concept embedding from query (JEPA-style)\n",
" 2. Retrieve similar code examples from concept bank\n",
" 3. Generate code conditioned on examples (Qwen3-Coder)\n",
" \"\"\"\n",
" \n",
" def __init__(\n",
" self,\n",
" concept_predictor: ConceptPredictor,\n",
" concept_bank: ConceptBank,\n",
" llm_name: str = \"Qwen/Qwen3-Coder-30B-A3B-Instruct\",\n",
" num_examples: int = 3,\n",
" use_4bit: bool = True\n",
" ):\n",
" \"\"\"\n",
" Initialize with latest Qwen3 Coder model (Jan 2026).\n",
" \n",
" LLM Options:\n",
" - \"Qwen/Qwen3-Coder-30B-A3B-Instruct\" (30B MoE, 3B active, SOTA)\n",
" - \"Qwen/Qwen2.5-Coder-7B-Instruct\" (7B, good fallback)\n",
" - \"Qwen/Qwen2.5-Coder-32B-Instruct\" (32B, if you have A100)\n",
" \"\"\"\n",
" self.concept_predictor = concept_predictor\n",
" self.concept_bank = concept_bank\n",
" self.num_examples = num_examples\n",
" \n",
" print(f\"Loading LLM: {llm_name}\")\n",
" \n",
" if use_4bit:\n",
" bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" bnb_4bit_use_double_quant=True\n",
" )\n",
" self.llm = AutoModelForCausalLM.from_pretrained(\n",
" llm_name,\n",
" quantization_config=bnb_config,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True,\n",
" attn_implementation=\"flash_attention_2\" if torch.cuda.is_available() else None\n",
" )\n",
" else:\n",
" self.llm = AutoModelForCausalLM.from_pretrained(\n",
" llm_name,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True\n",
" )\n",
" \n",
" self.tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True)\n",
" if self.tokenizer.pad_token is None:\n",
" self.tokenizer.pad_token = self.tokenizer.eos_token\n",
" \n",
" print(f\"LLM loaded successfully\")\n",
" \n",
" def retrieve_examples(self, query: str) -> List[Dict]:\n",
" \"\"\"Retrieve relevant code examples using predicted concept.\"\"\"\n",
" concept_embedding = self.concept_predictor.predict(query)\n",
" examples = self.concept_bank.search(concept_embedding, k=self.num_examples)\n",
" return examples\n",
" \n",
" def build_prompt(self, query: str, examples: List[Dict]) -> str:\n",
" \"\"\"Build few-shot prompt with retrieved examples.\"\"\"\n",
" prompt_parts = [\n",
" \"You are an expert Python programmer. Write clean, efficient, and well-documented code.\",\n",
" \"\",\n",
" \"Here are some similar code examples for reference:\"\n",
" ]\n",
" \n",
" for i, ex in enumerate(examples):\n",
" desc = ex['description'][:150] if ex['description'] else \"(utility function)\"\n",
" code = ex['code'][:500]\n",
" prompt_parts.append(f\"\\n### Example {i+1}: {desc}\")\n",
" prompt_parts.append(f\"```python\\n{code}\\n```\")\n",
" \n",
" prompt_parts.extend([\n",
" \"\",\n",
" f\"### Task: {query}\",\n",
" \"\",\n",
" \"Write the Python code:\",\n",
" \"```python\"\n",
" ])\n",
" \n",
" return \"\\n\".join(prompt_parts)\n",
" \n",
" def generate(\n",
" self, \n",
" query: str, \n",
" max_new_tokens: int = 512,\n",
" temperature: float = 0.2,\n",
" show_examples: bool = False\n",
" ) -> Dict:\n",
" \"\"\"\n",
" Generate code using concept-first approach.\n",
" \"\"\"\n",
" examples = self.retrieve_examples(query)\n",
" \n",
" if show_examples:\n",
" print(\"Retrieved concept-matched examples:\")\n",
" for i, ex in enumerate(examples):\n",
" desc = ex['description'][:40].replace('\\n', ' ') if ex['description'] else \"(no desc)\"\n",
" print(f\" [{i+1}] (sim={ex['similarity']:.3f}, src={ex['source']}) {desc}...\")\n",
" \n",
" prompt = self.build_prompt(query, examples)\n",
" \n",
" messages = [{\"role\": \"user\", \"content\": prompt}]\n",
" text = self.tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
" \n",
" inputs = self.tokenizer(text, return_tensors=\"pt\").to(self.llm.device)\n",
" \n",
" with torch.no_grad():\n",
" outputs = self.llm.generate(\n",
" **inputs,\n",
" max_new_tokens=max_new_tokens,\n",
" temperature=temperature,\n",
" do_sample=temperature > 0,\n",
" top_p=0.95,\n",
" pad_token_id=self.tokenizer.pad_token_id,\n",
" eos_token_id=self.tokenizer.eos_token_id\n",
" )\n",
" \n",
" generated = self.tokenizer.decode(\n",
" outputs[0][inputs[\"input_ids\"].shape[1]:], \n",
" skip_special_tokens=True\n",
" )\n",
" \n",
" if \"```\" in generated:\n",
" code = generated.split(\"```\")[0].strip()\n",
" else:\n",
" code = generated.strip()\n",
" \n",
" return {\n",
" \"code\": code,\n",
" \"examples\": examples,\n",
" \"concept_similarities\": [ex[\"similarity\"] for ex in examples]\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the concept-first generator\n",
"# Use Qwen2.5-Coder-7B for T4, Qwen3-Coder-30B-A3B for A100\n",
"generator = ConceptFirstCodeGenerator(\n",
" concept_predictor=concept_predictor,\n",
" concept_bank=concept_bank,\n",
" llm_name=\"Qwen/Qwen2.5-Coder-7B-Instruct\", # Use Qwen3-Coder-30B-A3B if you have A100\n",
" num_examples=3\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test generation!\n",
"print(\"=\" * 60)\n",
"print(\"CONCEPT-FIRST CODE GENERATION (Jan 2026)\")\n",
"print(\"=\" * 60)\n",
"\n",
"test_queries = [\n",
" \"write a function to compute the nth fibonacci number efficiently using memoization\",\n",
" \"implement a function to check if a number is prime\",\n",
" \"write a function to find all permutations of a string\",\n",
" \"implement a binary search tree with insert, search, and delete methods\"\n",
"]\n",
"\n",
"for query in test_queries:\n",
" print(f\"\\n{'='*60}\")\n",
" print(f\"Query: {query}\")\n",
" print(\"=\"*60)\n",
" \n",
" result = generator.generate(query, show_examples=True)\n",
" \n",
" print(f\"\\nGenerated Code:\")\n",
" print(\"-\" * 40)\n",
" print(result[\"code\"][:800])\n",
" print(\"-\" * 40)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 6: Comparison - Concept-First vs Direct Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_direct(generator, query: str, max_new_tokens: int = 512) -> str:\n",
" \"\"\"Generate code directly without concept guidance.\"\"\"\n",
" prompt = f\"\"\"You are an expert Python programmer. Write clean, efficient, and well-documented code.\n",
"\n",
"### Task: {query}\n",
"\n",
"Write the Python code:\n",
"```python\"\"\"\n",
" \n",
" messages = [{\"role\": \"user\", \"content\": prompt}]\n",
" text = generator.tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
" \n",
" inputs = generator.tokenizer(text, return_tensors=\"pt\").to(generator.llm.device)\n",
" \n",
" with torch.no_grad():\n",
" outputs = generator.llm.generate(\n",
" **inputs,\n",
" max_new_tokens=max_new_tokens,\n",
" temperature=0.2,\n",
" do_sample=True,\n",
" top_p=0.95,\n",
" pad_token_id=generator.tokenizer.pad_token_id\n",
" )\n",
" \n",
" generated = generator.tokenizer.decode(\n",
" outputs[0][inputs[\"input_ids\"].shape[1]:], \n",
" skip_special_tokens=True\n",
" )\n",
" \n",
" if \"```\" in generated:\n",
" return generated.split(\"```\")[0].strip()\n",
" return generated.strip()\n",
"\n",
"\n",
"# Compare approaches\n",
"print(\"=\" * 70)\n",
"print(\"COMPARISON: Concept-First vs Direct Generation\")\n",
"print(\"=\" * 70)\n",
"\n",
"comparison_queries = [\n",
" \"implement merge sort algorithm\",\n",
" \"write a function to validate an email address using regex\",\n",
" \"implement an LRU cache with O(1) get and put operations\"\n",
"]\n",
"\n",
"for query in comparison_queries:\n",
" print(f\"\\n{'='*70}\")\n",
" print(f\"Query: {query}\")\n",
" print(\"=\"*70)\n",
" \n",
" print(\"\\n[CONCEPT-FIRST] - Uses predicted concept to retrieve examples\")\n",
" result = generator.generate(query, show_examples=True)\n",
" print(f\"\\nGenerated:\")\n",
" print(result[\"code\"][:500])\n",
" \n",
" print(\"\\n\" + \"-\"*70)\n",
" print(\"[DIRECT] - No concept guidance, just the query\")\n",
" direct_code = generate_direct(generator, query)\n",
" print(direct_code[:500])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 7: Concept Embedding Visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -q matplotlib scikit-learn umap-learn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# Get embeddings for various queries grouped by concept\n",
"query_categories = {\n",
" \"recursion\": [\n",
" \"compute fibonacci recursively\",\n",
" \"calculate factorial\",\n",
" \"recursive tree traversal\",\n",
" \"recursive binary search\"\n",
" ],\n",
" \"sorting\": [\n",
" \"implement quicksort\",\n",
" \"bubble sort algorithm\",\n",
" \"merge sort implementation\",\n",
" \"heap sort\"\n",
" ],\n",
" \"strings\": [\n",
" \"reverse a string\",\n",
" \"check palindrome\",\n",
" \"count character frequency\",\n",
" \"find longest substring\"\n",
" ],\n",
" \"data_structures\": [\n",
" \"implement linked list\",\n",
" \"binary search tree\",\n",
" \"hash table implementation\",\n",
" \"stack using array\"\n",
" ]\n",
"}\n",
"\n",
"# Collect embeddings\n",
"all_queries = []\n",
"all_labels = []\n",
"all_embeddings = []\n",
"\n",
"for category, queries in query_categories.items():\n",
" for q in queries:\n",
" all_queries.append(q)\n",
" all_labels.append(category)\n",
" emb = concept_predictor.predict(q).cpu().numpy()\n",
" all_embeddings.append(emb)\n",
"\n",
"all_embeddings = np.stack(all_embeddings)\n",
"\n",
"# t-SNE visualization\n",
"tsne = TSNE(n_components=2, random_state=42, perplexity=5)\n",
"embeddings_2d = tsne.fit_transform(all_embeddings)\n",
"\n",
"# Plot\n",
"plt.figure(figsize=(12, 10))\n",
"colors = {\n",
" \"recursion\": \"#e41a1c\", \n",
" \"sorting\": \"#377eb8\", \n",
" \"strings\": \"#4daf4a\", \n",
" \"data_structures\": \"#984ea3\"\n",
"}\n",
"\n",
"for i, (q, label) in enumerate(zip(all_queries, all_labels)):\n",
" plt.scatter(embeddings_2d[i, 0], embeddings_2d[i, 1], \n",
" c=colors[label], s=150, alpha=0.7, edgecolors='white', linewidth=1)\n",
" plt.annotate(q[:20], (embeddings_2d[i, 0]+0.5, embeddings_2d[i, 1]+0.5), fontsize=8)\n",
"\n",
"for label, color in colors.items():\n",
" plt.scatter([], [], c=color, label=label.replace('_', ' ').title(), s=150)\n",
"plt.legend(loc='upper right', fontsize=10)\n",
"\n",
"plt.title(\"Concept Space Visualization (t-SNE)\\nSimilar coding concepts cluster together\", fontsize=14)\n",
"plt.xlabel(\"Dimension 1\", fontsize=12)\n",
"plt.ylabel(\"Dimension 2\", fontsize=12)\n",
"plt.tight_layout()\n",
"plt.savefig(\"concept_space.png\", dpi=150)\n",
"plt.show()\n",
"\n",
"print(\"\\nSaved visualization to concept_space.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 8: Save Models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save concept predictor\n",
"torch.save({\n",
" \"projector_state_dict\": concept_predictor.projector.state_dict(),\n",
" \"concept_dim\": concept_predictor.concept_dim,\n",
" \"text_dim\": concept_predictor.text_dim,\n",
"}, \"concept_predictor.pt\")\n",
"\n",
"# Save concept bank embeddings\n",
"torch.save({\n",
" \"embeddings\": concept_bank.embeddings,\n",
" \"codes\": concept_bank.codes,\n",
" \"descriptions\": concept_bank.descriptions,\n",
" \"sources\": concept_bank.sources\n",
"}, \"concept_bank.pt\")\n",
"\n",
"print(\"Models saved!\")\n",
"print(\" - concept_predictor.pt\")\n",
"print(\" - concept_bank.pt\")\n",
"print(f\"\\nConcept bank: {len(concept_bank.codes)} concepts\")\n",
"print(f\"Embedding dim: {concept_bank.embeddings.shape[1]}\")"
]
},
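{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload sketch (an added example, assuming the two files saved above and\n",
"# the same encoder checkpoints). Only the trainable projector weights are\n",
"# restored; ConceptPredictor re-loads the frozen text encoder itself.\n",
"ckpt = torch.load(\"concept_predictor.pt\", map_location=\"cpu\")\n",
"restored_predictor = ConceptPredictor(concept_dim=ckpt[\"concept_dim\"])\n",
"restored_predictor.projector.load_state_dict(ckpt[\"projector_state_dict\"])\n",
"restored_predictor.eval()\n",
"\n",
"bank_ckpt = torch.load(\"concept_bank.pt\", map_location=\"cpu\")\n",
"restored_bank = ConceptBank(concept_encoder)\n",
"restored_bank.embeddings = bank_ckpt[\"embeddings\"]\n",
"restored_bank.codes = bank_ckpt[\"codes\"]\n",
"restored_bank.descriptions = bank_ckpt[\"descriptions\"]\n",
"restored_bank.sources = bank_ckpt[\"sources\"]\n",
"print(f\"Restored predictor and {len(restored_bank.codes)} bank concepts\")"
]
},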
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"We implemented a **Concept-First Code Generation** pipeline inspired by VL-JEPA:\n",
"\n",
"### Models Used (January 2026 - Latest)\n",
"\n",
"| Component | Model | Size | Notes |\n",
"|-----------|-------|------|-------|\n",
"| Code Encoder | `Salesforce/SFR-Embedding-Code-2B_R` | 2B | SOTA CoIR 67.4 |\n",
"| Text Encoder | `Alibaba-NLP/gte-Qwen2-1.5B-instruct` | 1.5B | Latest GTE |\n",
"| Code LLM | `Qwen/Qwen3-Coder-30B-A3B-Instruct` | 30B (3B active) | Latest MoE coder |\n",
"\n",
"### Datasets (2026)\n",
"- MBPP (google-research-datasets)\n",
"- HumanEval (OpenAI)\n",
"- CodeSearchNet Python\n",
"- Evol-Instruct-Code-80k\n",
"- Magicoder-OSS-Instruct-75K\n",
"\n",
"### Key Insights\n",
"\n",
"1. **Concept prediction** helps form the \"bigger picture\" before generating tokens\n",
"2. **SFR-Embedding-Code** provides SOTA code embeddings for retrieval\n",
"3. **Semantic clustering** in concept space groups related patterns\n",
"4. **Qwen3-Coder MoE** gives best generation quality with efficient inference\n",
"\n",
"### Next Steps\n",
"\n",
"1. **Hierarchical concepts**: Program -> Function -> Block level\n",
"2. **Integration with RLM**: Use concepts as working memory\n",
"3. **Distillation to Distillix**: Train small BitNet to predict concepts\n",
"4. **Evaluation**: HumanEval, MBPP pass@1 comparison"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}