Concept-First Code Generation: JEPA-style concept prediction for code (VL-JEPA inspired)
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Concept-First Code Generation\n", | |
| "\n", | |
| "**Inspired by VL-JEPA**: Predict concept embeddings first, then generate code conditioned on them.\n", | |
| "\n", | |
| "## The Idea\n", | |
| "\n", | |
| "Traditional autoregressive models predict tokens one at a time, which can lead to:\n", | |
| "- Losing coherence over long generations\n", | |
| "- Hallucinating APIs\n", | |
| "- Repetition loops\n", | |
| "\n", | |
| "**Concept-First** approach:\n", | |
| "1. **Concept Encoder**: Encode code snippets into semantic embeddings\n", | |
| "2. **Concept Predictor**: Given a query, predict what the code embedding should look like\n", | |
| "3. **Concept-Conditioned Generation**: Generate code guided by the predicted concept\n", | |
| "\n", | |
| "```\n", | |
| "Query: \"Write fibonacci\" \n", | |
| " ↓\n", | |
| "Concept Predictor → [0.23, -0.87, ...] (embedding)\n", | |
| " ↓\n", | |
| "Retrieve similar: [\"def fib(n): ...\", \"def factorial(n): ...\"]\n", | |
| " ↓\n", | |
| "Conditioned Generation → \"def fibonacci(n):\\n if n <= 1: ...\"\n", | |
| "```\n", | |
| "\n", | |
| "## Models Used (January 2026 - Latest)\n", | |
| "\n", | |
| "| Component | Model | Why |\n", | |
| "|-----------|-------|-----|\n", | |
| "| **Code Encoder** | `Salesforce/SFR-Embedding-Code-2B_R` | SOTA code embeddings (CoIR: 67.4), 2B params |\n", | |
| "| **Text Encoder** | `Alibaba-NLP/gte-Qwen2-1.5B-instruct` | Latest GTE with instruction support |\n", | |
| "| **Code LLM** | `Qwen/Qwen3-Coder-30B-A3B-Instruct` | Latest Qwen3 Coder MoE (Jan 2026) |\n", | |
| "| **Dataset** | `bigcode/the-stack-v2` + `MBPP` + `HumanEval` + `Evol-Instruct` | Diverse, high-quality code |" | |
| ] | |
| }, | |
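| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Before building the real pipeline, here is a toy sketch of the three-stage flow using hand-made 3-d \"embeddings\" and a stubbed predictor. Everything in it is illustrative only; the actual encoder, predictor, and generator are built in the parts below." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Toy illustration of the concept-first flow (stub models, not the real ones).\n", | |
| "import torch\n", | |
| "import torch.nn.functional as F\n", | |
| "\n", | |
| "# A tiny 'concept bank': code snippets with hand-made 3-d embeddings.\n", | |
| "bank_codes = [\"def fib(n): ...\", \"def sort(a): ...\"]\n", | |
| "bank_embs = F.normalize(torch.tensor([[1.0, 0.1, 0.0], [0.0, 1.0, 0.2]]), dim=-1)\n", | |
| "\n", | |
| "def predict_concept(query: str) -> torch.Tensor:\n", | |
| "    \"\"\"Stub predictor: map a query into the same 3-d concept space.\"\"\"\n", | |
| "    vec = [1.0, 0.0, 0.0] if \"fib\" in query else [0.0, 1.0, 0.0]\n", | |
| "    return F.normalize(torch.tensor(vec), dim=0)\n", | |
| "\n", | |
| "# Predict a concept, then retrieve the nearest bank entry to condition on.\n", | |
| "concept = predict_concept(\"write fibonacci\")\n", | |
| "best = (bank_embs @ concept).argmax().item()\n", | |
| "print(f\"Predicted concept -> retrieved example: {bank_codes[best]}\")" | |
| ] | |
| }, | |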
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Setup" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Install dependencies (latest versions as of Jan 2026)\n", | |
| "!pip install -q --upgrade transformers>=4.57.0 datasets>=3.0.0 torch>=2.5.0 \n", | |
| "!pip install -q --upgrade sentence-transformers>=3.3.0 accelerate>=1.2.0 bitsandbytes>=0.45.0\n", | |
| "!pip install -q --upgrade huggingface_hub einops" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "from torch.utils.data import DataLoader, Dataset\n", | |
| "from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig\n", | |
| "from sentence_transformers import SentenceTransformer\n", | |
| "from datasets import load_dataset, concatenate_datasets\n", | |
| "import numpy as np\n", | |
| "from typing import List, Dict, Tuple, Optional\n", | |
| "import json\n", | |
| "from tqdm.auto import tqdm\n", | |
| "import warnings\n", | |
| "warnings.filterwarnings('ignore')\n", | |
| "\n", | |
| "# Check GPU\n", | |
| "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
| "print(f\"Using device: {device}\")\n", | |
| "if torch.cuda.is_available():\n", | |
| " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", | |
| " print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", | |
| " \n", | |
| "# Print versions\n", | |
| "import transformers, datasets, sentence_transformers\n", | |
| "print(f\"\\nLibrary versions (Jan 2026):\")\n", | |
| "print(f\" transformers: {transformers.__version__}\")\n", | |
| "print(f\" datasets: {datasets.__version__}\")\n", | |
| "print(f\" sentence-transformers: {sentence_transformers.__version__}\")\n", | |
| "print(f\" torch: {torch.__version__}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 1: Concept Encoder (SFR-Embedding-Code-2B)\n", | |
| "\n", | |
| "We use **Salesforce SFR-Embedding-Code-2B** - the current SOTA for code embeddings.\n", | |
| "- CoIR benchmark: 67.4 NDCG@10 (best in class)\n", | |
| "- Supports code-to-code and text-to-code retrieval\n", | |
| "- 2B parameters, 32K context length" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ConceptEncoder:\n", | |
| " \"\"\"\n", | |
| " Encodes code snippets into semantic concept embeddings.\n", | |
| " Uses SFR-Embedding-Code-2B (SOTA for code embeddings, Jan 2026).\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(self, model_name: str = \"Salesforce/SFR-Embedding-Code-2B_R\"):\n", | |
| " \"\"\"\n", | |
| " Initialize with Salesforce SFR-Embedding-Code.\n", | |
| " \n", | |
| " Alternatives (2026):\n", | |
| " - \"Salesforce/SFR-Embedding-Code-2B_R\" (2B, SOTA CoIR 67.4)\n", | |
| " - \"Salesforce/SFR-Embedding-Code-400M_R\" (400M, faster)\n", | |
| " - \"jinaai/jina-embeddings-v3\" (if available)\n", | |
| " \"\"\"\n", | |
| " print(f\"Loading concept encoder: {model_name}\")\n", | |
| " self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)\n", | |
| " self.model.to(device)\n", | |
| " self.model.eval()\n", | |
| " \n", | |
| " # SFR-Embedding-Code uses 2048-dim embeddings\n", | |
| " self.embed_dim = 2048\n", | |
| " self.max_length = 8192 # Can go up to 32K but 8K is efficient\n", | |
| " \n", | |
| " print(f\"Embedding dimension: {self.embed_dim}\")\n", | |
| " print(f\"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}\")\n", | |
| " \n", | |
| " @torch.no_grad()\n", | |
| " def encode(self, code: str) -> torch.Tensor:\n", | |
| " \"\"\"Encode a single code snippet.\"\"\"\n", | |
| " embeddings = self.model.encode_corpus([code], max_length=self.max_length)\n", | |
| " embeddings = F.normalize(embeddings, p=2, dim=-1)\n", | |
| " return embeddings[0]\n", | |
| " \n", | |
| " @torch.no_grad()\n", | |
| " def encode_batch(self, codes: List[str], batch_size: int = 16) -> torch.Tensor:\n", | |
| " \"\"\"Encode multiple code snippets.\"\"\"\n", | |
| " all_embeddings = []\n", | |
| " \n", | |
| " for i in tqdm(range(0, len(codes), batch_size), desc=\"Encoding\", leave=False):\n", | |
| " batch = codes[i:i+batch_size]\n", | |
| " embeddings = self.model.encode_corpus(batch, max_length=self.max_length)\n", | |
| " embeddings = F.normalize(embeddings, p=2, dim=-1)\n", | |
| " all_embeddings.append(embeddings.cpu())\n", | |
| " \n", | |
| " return torch.cat(all_embeddings, dim=0)\n", | |
| " \n", | |
| " @torch.no_grad()\n", | |
| " def encode_query(self, query: str, instruction: str = \"Given Code or Text, retrieve relevant code\") -> torch.Tensor:\n", | |
| " \"\"\"Encode a text query with instruction (for text-to-code retrieval).\"\"\"\n", | |
| " embeddings = self.model.encode_queries([query], instruction=instruction, max_length=self.max_length)\n", | |
| " embeddings = F.normalize(embeddings, p=2, dim=-1)\n", | |
| " return embeddings[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Initialize concept encoder\n", | |
| "# Use 400M version for faster iteration, 2B for best quality\n", | |
| "concept_encoder = ConceptEncoder(model_name=\"Salesforce/SFR-Embedding-Code-400M_R\")\n", | |
| "\n", | |
| "# Test it with similar code patterns\n", | |
| "test_codes = [\n", | |
| " \"def fibonacci(n):\\n if n <= 1:\\n return n\\n return fibonacci(n-1) + fibonacci(n-2)\",\n", | |
| " \"def factorial(n):\\n if n <= 1:\\n return 1\\n return n * factorial(n-1)\",\n", | |
| " \"def bubble_sort(arr):\\n for i in range(len(arr)):\\n for j in range(len(arr)-1):\\n if arr[j] > arr[j+1]:\\n arr[j], arr[j+1] = arr[j+1], arr[j]\",\n", | |
| " \"def binary_search(arr, x):\\n low, high = 0, len(arr)-1\\n while low <= high:\\n mid = (low + high) // 2\\n if arr[mid] == x:\\n return mid\"\n", | |
| "]\n", | |
| "\n", | |
| "embeddings = concept_encoder.encode_batch(test_codes)\n", | |
| "print(f\"Embeddings shape: {embeddings.shape}\")\n", | |
| "\n", | |
| "# Compute similarity matrix\n", | |
| "similarity = embeddings @ embeddings.T\n", | |
| "print(\"\\nSimilarity matrix (recursive functions should cluster together):\")\n", | |
| "labels = [\"fibonacci\", \"factorial\", \"bubble_sort\", \"binary_search\"]\n", | |
| "print(f\"{'':15}\", end=\"\")\n", | |
| "for l in labels:\n", | |
| " print(f\"{l:15}\", end=\"\")\n", | |
| "print()\n", | |
| "for i, l in enumerate(labels):\n", | |
| " print(f\"{l:15}\", end=\"\")\n", | |
| " for j in range(len(labels)):\n", | |
| " print(f\"{similarity[i,j]:.3f} \", end=\"\")\n", | |
| " print()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 2: Build Concept Bank from Multiple Datasets (2026)\n", | |
| "\n", | |
| "We combine multiple high-quality code datasets:\n", | |
| "1. **MBPP** - Curated Python problems with descriptions\n", | |
| "2. **HumanEval** - OpenAI's code benchmark\n", | |
| "3. **CodeSearchNet Python** - Large-scale code with docstrings\n", | |
| "4. **Evol-Instruct-Code** - High quality instruction-following code\n", | |
| "5. **The Stack v2** (sample) - Latest open-source code collection" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ConceptBank:\n", | |
| " \"\"\"\n", | |
| " A searchable bank of code concepts.\n", | |
| " Maps embeddings to code snippets for retrieval.\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(self, encoder: ConceptEncoder):\n", | |
| " self.encoder = encoder\n", | |
| " self.embeddings = None # (N, embed_dim)\n", | |
| " self.codes = [] # List of code strings\n", | |
| " self.descriptions = [] # List of descriptions/docstrings\n", | |
| " self.sources = [] # Track where each example came from\n", | |
| " \n", | |
| " def add(self, codes: List[str], descriptions: List[str] = None, source: str = \"unknown\"):\n", | |
| " \"\"\"Add code snippets to the bank.\"\"\"\n", | |
| " if descriptions is None:\n", | |
| " descriptions = [\"\"] * len(codes)\n", | |
| " \n", | |
| " # Filter out empty/invalid codes\n", | |
| " valid_pairs = [(c, d) for c, d in zip(codes, descriptions) if c and len(c.strip()) > 10]\n", | |
| " if not valid_pairs:\n", | |
| " print(f\" No valid codes from {source}\")\n", | |
| " return\n", | |
| " \n", | |
| " codes, descriptions = zip(*valid_pairs)\n", | |
| " codes, descriptions = list(codes), list(descriptions)\n", | |
| " \n", | |
| " print(f\" Encoding {len(codes)} examples from {source}...\")\n", | |
| " new_embeddings = self.encoder.encode_batch(codes)\n", | |
| " \n", | |
| " if self.embeddings is None:\n", | |
| " self.embeddings = new_embeddings\n", | |
| " else:\n", | |
| " self.embeddings = torch.cat([self.embeddings, new_embeddings], dim=0)\n", | |
| " \n", | |
| " self.codes.extend(codes)\n", | |
| " self.descriptions.extend(descriptions)\n", | |
| " self.sources.extend([source] * len(codes))\n", | |
| " print(f\" Bank size: {len(self.codes)} concepts\")\n", | |
| " \n", | |
| " def search(self, query_embedding: torch.Tensor, k: int = 5) -> List[Dict]:\n", | |
| " \"\"\"Find k nearest concepts to the query embedding.\"\"\"\n", | |
| " query_embedding = query_embedding.cpu()\n", | |
| " if query_embedding.dim() == 1:\n", | |
| " query_embedding = query_embedding.unsqueeze(0)\n", | |
| " \n", | |
| " similarities = (query_embedding @ self.embeddings.T).squeeze(0)\n", | |
| " top_k = similarities.topk(min(k, len(self.codes)))\n", | |
| " \n", | |
| " results = []\n", | |
| " for idx, score in zip(top_k.indices.tolist(), top_k.values.tolist()):\n", | |
| " results.append({\n", | |
| " \"code\": self.codes[idx],\n", | |
| " \"description\": self.descriptions[idx],\n", | |
| " \"similarity\": score,\n", | |
| " \"source\": self.sources[idx]\n", | |
| " })\n", | |
| " return results\n", | |
| " \n", | |
| " def search_by_code(self, code: str, k: int = 5) -> List[Dict]:\n", | |
| " \"\"\"Find similar code snippets.\"\"\"\n", | |
| " embedding = self.encoder.encode(code)\n", | |
| " return self.search(embedding, k)\n", | |
| " \n", | |
| " def search_by_text(self, text: str, k: int = 5) -> List[Dict]:\n", | |
| " \"\"\"Find code matching a text description (uses query encoder with instruction).\"\"\"\n", | |
| " embedding = self.encoder.encode_query(text)\n", | |
| " return self.search(embedding, k)\n", | |
| " \n", | |
| " def stats(self):\n", | |
| " \"\"\"Print statistics about the concept bank.\"\"\"\n", | |
| " from collections import Counter\n", | |
| " source_counts = Counter(self.sources)\n", | |
| " print(f\"\\nConcept Bank Statistics:\")\n", | |
| " print(f\" Total concepts: {len(self.codes)}\")\n", | |
| " print(f\" Embedding dim: {self.embeddings.shape[1]}\")\n", | |
| " print(f\" Sources:\")\n", | |
| " for source, count in source_counts.most_common():\n", | |
| " print(f\" - {source}: {count}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Load multiple datasets (2026 versions)\n", | |
| "print(\"Loading code datasets (Jan 2026)...\")\n", | |
| "print(\"=\" * 50)\n", | |
| "\n", | |
| "all_codes = []\n", | |
| "all_descriptions = []\n", | |
| "all_sources = []\n", | |
| "\n", | |
| "# 1. MBPP - Curated Python problems (google-research version)\n", | |
| "print(\"\\n1. Loading MBPP...\")\n", | |
| "try:\n", | |
| " mbpp = load_dataset(\"google-research-datasets/mbpp\", \"full\", split=\"train\", trust_remote_code=True)\n", | |
| " for ex in mbpp:\n", | |
| " all_codes.append(ex[\"code\"])\n", | |
| " all_descriptions.append(ex[\"text\"])\n", | |
| " all_sources.append(\"mbpp\")\n", | |
| " print(f\" Added {len(mbpp)} examples\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Failed: {e}\")\n", | |
| "\n", | |
| "# 2. HumanEval - OpenAI benchmark (latest version)\n", | |
| "print(\"\\n2. Loading HumanEval...\")\n", | |
| "try:\n", | |
| " humaneval = load_dataset(\"openai/openai_humaneval\", split=\"test\", trust_remote_code=True)\n", | |
| " for ex in humaneval:\n", | |
| " code = ex[\"prompt\"] + ex[\"canonical_solution\"]\n", | |
| " all_codes.append(code)\n", | |
| " desc = ex[\"prompt\"].split('\"\"\"')[1] if '\"\"\"' in ex[\"prompt\"] else ex[\"entry_point\"]\n", | |
| " all_descriptions.append(desc.strip())\n", | |
| " all_sources.append(\"humaneval\")\n", | |
| " print(f\" Added {len(humaneval)} examples\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Failed: {e}\")\n", | |
| "\n", | |
| "# 3. CodeSearchNet Python - Large scale with docstrings\n", | |
| "print(\"\\n3. Loading CodeSearchNet (Python subset)...\")\n", | |
| "try:\n", | |
| " csn = load_dataset(\"code-search-net/code_search_net\", \"python\", split=\"train\", trust_remote_code=True)\n", | |
| " csn_sample = csn.shuffle(seed=42).select(range(min(5000, len(csn))))\n", | |
| " for ex in csn_sample:\n", | |
| " if ex[\"func_code_string\"] and ex[\"func_documentation_string\"]:\n", | |
| " all_codes.append(ex[\"func_code_string\"])\n", | |
| " all_descriptions.append(ex[\"func_documentation_string\"])\n", | |
| " all_sources.append(\"codesearchnet\")\n", | |
| " print(f\" Added {len(csn_sample)} examples\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Failed: {e}\")\n", | |
| "\n", | |
| "# 4. Evol-Instruct-Code - High quality instruction-following code\n", | |
| "print(\"\\n4. Loading Evol-Instruct-Code...\")\n", | |
| "try:\n", | |
| " evol = load_dataset(\"nickrosh/Evol-Instruct-Code-80k-v1\", split=\"train\", trust_remote_code=True)\n", | |
| " evol_sample = evol.shuffle(seed=42).select(range(min(3000, len(evol))))\n", | |
| " count = 0\n", | |
| " for ex in evol_sample:\n", | |
| " output = ex[\"output\"]\n", | |
| " if \"```python\" in output:\n", | |
| " code = output.split(\"```python\")[1].split(\"```\")[0].strip()\n", | |
| " if len(code) > 20:\n", | |
| " all_codes.append(code)\n", | |
| " all_descriptions.append(ex[\"instruction\"][:200])\n", | |
| " all_sources.append(\"evol-instruct\")\n", | |
| " count += 1\n", | |
| " print(f\" Added {count} examples from Evol-Instruct\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Failed: {e}\")\n", | |
| "\n", | |
| "# 5. Magicoder-OSS-Instruct - High quality code instructions (2024+)\n", | |
| "print(\"\\n5. Loading Magicoder-OSS-Instruct...\")\n", | |
| "try:\n", | |
| " magic = load_dataset(\"ise-uiuc/Magicoder-OSS-Instruct-75K\", split=\"train\", trust_remote_code=True)\n", | |
| " magic_sample = magic.shuffle(seed=42).select(range(min(2000, len(magic))))\n", | |
| " count = 0\n", | |
| " for ex in magic_sample:\n", | |
| " if \"```python\" in ex[\"solution\"]:\n", | |
| " code = ex[\"solution\"].split(\"```python\")[1].split(\"```\")[0].strip()\n", | |
| " if len(code) > 20:\n", | |
| " all_codes.append(code)\n", | |
| " all_descriptions.append(ex[\"problem\"][:200])\n", | |
| " all_sources.append(\"magicoder\")\n", | |
| " count += 1\n", | |
| " print(f\" Added {count} examples from Magicoder\")\n", | |
| "except Exception as e:\n", | |
| " print(f\" Failed: {e}\")\n", | |
| "\n", | |
| "print(f\"\\n{'='*50}\")\n", | |
| "print(f\"Total collected: {len(all_codes)} code examples\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Build concept bank\n", | |
| "print(\"\\nBuilding concept bank...\")\n", | |
| "concept_bank = ConceptBank(concept_encoder)\n", | |
| "\n", | |
| "# Add in batches by source for better tracking\n", | |
| "from collections import defaultdict\n", | |
| "by_source = defaultdict(lambda: {\"codes\": [], \"descs\": []})\n", | |
| "for code, desc, source in zip(all_codes, all_descriptions, all_sources):\n", | |
| " by_source[source][\"codes\"].append(code)\n", | |
| " by_source[source][\"descs\"].append(desc)\n", | |
| "\n", | |
| "for source, data in by_source.items():\n", | |
| " concept_bank.add(data[\"codes\"], data[\"descs\"], source=source)\n", | |
| "\n", | |
| "concept_bank.stats()" | |
| ] | |
| }, | |
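| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The bank's `search` uses exact brute-force matmul, which is fine at this scale. For much larger banks, the same lookups can be served by a FAISS index. The sketch below is optional and assumes `faiss-cpu` is installed (`pip install faiss-cpu`); since the embeddings are L2-normalized, inner product equals cosine similarity." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Optional sketch: serve concept-bank lookups from a FAISS index.\n", | |
| "# Assumes `pip install faiss-cpu`; skip this cell otherwise.\n", | |
| "import faiss\n", | |
| "\n", | |
| "emb_np = concept_bank.embeddings.numpy().astype(\"float32\")\n", | |
| "index = faiss.IndexFlatIP(emb_np.shape[1])  # inner product == cosine (normalized)\n", | |
| "index.add(emb_np)\n", | |
| "\n", | |
| "q = concept_encoder.encode_query(\"fibonacci sequence\").cpu().numpy().astype(\"float32\")[None, :]\n", | |
| "scores, ids = index.search(q, 3)\n", | |
| "for s, i in zip(scores[0], ids[0]):\n", | |
| "    print(f\"{s:.3f}  {concept_bank.descriptions[i][:60]}\")" | |
| ] | |
| }, | |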
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Test retrieval with text-to-code\n", | |
| "print(\"=\" * 60)\n", | |
| "print(\"Testing concept retrieval (text-to-code)\")\n", | |
| "print(\"=\" * 60)\n", | |
| "\n", | |
| "test_queries = [\n", | |
| " \"fibonacci sequence recursive implementation\",\n", | |
| " \"sort a list using quicksort\",\n", | |
| " \"check if string is palindrome\",\n", | |
| " \"binary tree traversal inorder\",\n", | |
| " \"read and parse json file\"\n", | |
| "]\n", | |
| "\n", | |
| "for query in test_queries:\n", | |
| " print(f\"\\nQuery: '{query}'\")\n", | |
| " print(\"-\" * 40)\n", | |
| " results = concept_bank.search_by_text(query, k=2)\n", | |
| " for i, r in enumerate(results):\n", | |
| " desc_preview = r['description'][:60].replace('\\n', ' ') if r['description'] else \"(no description)\"\n", | |
| " print(f\" [{i+1}] (sim={r['similarity']:.3f}, src={r['source']}) {desc_preview}...\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 3: Concept Predictor (JEPA-style)\n", | |
| "\n", | |
| "Using **GTE-Qwen2-1.5B-instruct** - latest GTE model with instruction following.\n", | |
| "This encodes natural language queries and predicts the concept embedding." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ConceptPredictor(nn.Module):\n", | |
| " \"\"\"\n", | |
| " JEPA-style concept predictor.\n", | |
| " Given a text query, predicts the concept embedding of the target code.\n", | |
| " \n", | |
| " Architecture:\n", | |
| " - Text encoder (frozen): GTE-Qwen2 (latest MTEB leader)\n", | |
| " - Projection head (trainable): Map text embedding to code concept space\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(\n", | |
| " self, \n", | |
| " text_encoder_name: str = \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\",\n", | |
| " concept_dim: int = 2048,\n", | |
| " hidden_dim: int = 2048\n", | |
| " ):\n", | |
| " \"\"\"\n", | |
| " Initialize with state-of-the-art text encoder (Jan 2026).\n", | |
| " \n", | |
| " Alternatives:\n", | |
| " - \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\" (1.5B, best quality)\n", | |
| " - \"Alibaba-NLP/gte-large-en-v1.5\" (434M, faster)\n", | |
| " - \"BAAI/bge-m3\" (multilingual)\n", | |
| " \"\"\"\n", | |
| " super().__init__()\n", | |
| " \n", | |
| " # Text encoder (frozen) - use sentence-transformers for easy loading\n", | |
| " print(f\"Loading text encoder: {text_encoder_name}\")\n", | |
| " try:\n", | |
| " self.text_encoder = SentenceTransformer(text_encoder_name, trust_remote_code=True)\n", | |
| " except:\n", | |
| " # Fallback to a reliable model\n", | |
| " print(\" Falling back to gte-large-en-v1.5\")\n", | |
| " text_encoder_name = \"Alibaba-NLP/gte-large-en-v1.5\"\n", | |
| " self.text_encoder = SentenceTransformer(text_encoder_name, trust_remote_code=True)\n", | |
| " \n", | |
| " self.text_encoder.to(device)\n", | |
| " text_dim = self.text_encoder.get_sentence_embedding_dimension()\n", | |
| " print(f\"Text embedding dim: {text_dim}\")\n", | |
| " \n", | |
| " # Freeze text encoder\n", | |
| " for param in self.text_encoder.parameters():\n", | |
| " param.requires_grad = False\n", | |
| " \n", | |
| " # Projection head (trainable) - maps text space to code concept space\n", | |
| " self.projector = nn.Sequential(\n", | |
| " nn.Linear(text_dim, hidden_dim),\n", | |
| " nn.GELU(),\n", | |
| " nn.LayerNorm(hidden_dim),\n", | |
| " nn.Dropout(0.1),\n", | |
| " nn.Linear(hidden_dim, hidden_dim),\n", | |
| " nn.GELU(),\n", | |
| " nn.LayerNorm(hidden_dim),\n", | |
| " nn.Dropout(0.1),\n", | |
| " nn.Linear(hidden_dim, concept_dim)\n", | |
| " ).to(device)\n", | |
| " \n", | |
| " self.concept_dim = concept_dim\n", | |
| " self.text_dim = text_dim\n", | |
| " \n", | |
| " # Count parameters\n", | |
| " trainable = sum(p.numel() for p in self.projector.parameters())\n", | |
| " print(f\"Trainable parameters: {trainable:,}\")\n", | |
| " \n", | |
| " def encode_text(self, texts: List[str]) -> torch.Tensor:\n", | |
| " \"\"\"Encode text queries using frozen encoder.\"\"\"\n", | |
| " with torch.no_grad():\n", | |
| " embeddings = self.text_encoder.encode(\n", | |
| " texts, \n", | |
| " convert_to_tensor=True,\n", | |
| " device=device,\n", | |
| " show_progress_bar=False\n", | |
| " )\n", | |
| " return embeddings\n", | |
| " \n", | |
| " def forward(self, texts: List[str]) -> torch.Tensor:\n", | |
| " \"\"\"Predict concept embeddings from text queries.\"\"\"\n", | |
| " text_embeddings = self.encode_text(texts)\n", | |
| " concept_embeddings = self.projector(text_embeddings)\n", | |
| " return F.normalize(concept_embeddings, p=2, dim=-1)\n", | |
| " \n", | |
| " def predict(self, query: str) -> torch.Tensor:\n", | |
| " \"\"\"Predict concept embedding for a single query.\"\"\"\n", | |
| " self.eval()\n", | |
| " with torch.no_grad():\n", | |
| " return self.forward([query])[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Initialize concept predictor\n", | |
| "concept_predictor = ConceptPredictor(concept_dim=concept_encoder.embed_dim)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 4: Training the Concept Predictor\n", | |
| "\n", | |
| "We train with **InfoNCE loss** (like VL-JEPA and CLIP):\n", | |
| "- Positive: (query, correct_code_embedding)\n", | |
| "- Negatives: (query, other_code_embeddings in batch)\n", | |
| "\n", | |
| "The predictor learns to map natural language queries to the code concept space." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ConceptDataset(Dataset):\n", | |
| " \"\"\"Dataset of (description, code) pairs.\"\"\"\n", | |
| " \n", | |
| " def __init__(self, descriptions: List[str], codes: List[str]):\n", | |
| " # Filter pairs where both exist\n", | |
| " self.pairs = [\n", | |
| " (d, c) for d, c in zip(descriptions, codes) \n", | |
| " if d and c and len(d.strip()) > 5 and len(c.strip()) > 10\n", | |
| " ]\n", | |
| " print(f\"Dataset: {len(self.pairs)} valid pairs\")\n", | |
| " \n", | |
| " def __len__(self):\n", | |
| " return len(self.pairs)\n", | |
| " \n", | |
| " def __getitem__(self, idx):\n", | |
| " desc, code = self.pairs[idx]\n", | |
| " return {\"description\": desc, \"code\": code}\n", | |
| "\n", | |
| "\n", | |
| "def info_nce_loss(\n", | |
| " predicted: torch.Tensor,\n", | |
| " target: torch.Tensor,\n", | |
| " temperature: float = 0.07\n", | |
| ") -> torch.Tensor:\n", | |
| " \"\"\"\n", | |
| " InfoNCE contrastive loss (same as CLIP/VL-JEPA).\n", | |
| " \"\"\"\n", | |
| " predicted = F.normalize(predicted, p=2, dim=-1)\n", | |
| " target = F.normalize(target, p=2, dim=-1)\n", | |
| " \n", | |
| " logits = (predicted @ target.T) / temperature\n", | |
| " labels = torch.arange(len(predicted), device=predicted.device)\n", | |
| " \n", | |
| " loss_p2t = F.cross_entropy(logits, labels)\n", | |
| " loss_t2p = F.cross_entropy(logits.T, labels)\n", | |
| " \n", | |
| " return (loss_p2t + loss_t2p) / 2" | |
| ] | |
| }, | |
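| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A quick sanity check on `info_nce_loss`: when predictions exactly match targets, the loss should be near zero at this temperature, and much larger when rows are mismatched." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sanity check: aligned predictions -> near-zero loss; shuffled -> large loss.\n", | |
| "B, D = 8, 16\n", | |
| "targets = F.normalize(torch.randn(B, D), dim=-1)\n", | |
| "aligned_loss = info_nce_loss(targets, targets)\n", | |
| "shuffled_loss = info_nce_loss(targets[torch.randperm(B)], targets)\n", | |
| "print(f\"aligned: {aligned_loss:.4f}, shuffled: {shuffled_loss:.4f}\")" | |
| ] | |
| }, | |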
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def train_concept_predictor(\n", | |
| " predictor: ConceptPredictor,\n", | |
| " encoder: ConceptEncoder,\n", | |
| " descriptions: List[str],\n", | |
| " codes: List[str],\n", | |
| " epochs: int = 15,\n", | |
| " batch_size: int = 32,\n", | |
| " lr: float = 2e-4,\n", | |
| " warmup_ratio: float = 0.1\n", | |
| "):\n", | |
| " \"\"\"\n", | |
| " Train the concept predictor with InfoNCE loss.\n", | |
| " \"\"\"\n", | |
| " dataset = ConceptDataset(descriptions, codes)\n", | |
| " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)\n", | |
| " \n", | |
| " optimizer = torch.optim.AdamW(predictor.projector.parameters(), lr=lr, weight_decay=0.01)\n", | |
| " \n", | |
| " total_steps = epochs * len(dataloader)\n", | |
| " warmup_steps = int(total_steps * warmup_ratio)\n", | |
| " \n", | |
| " def lr_lambda(step):\n", | |
| " if step < warmup_steps:\n", | |
| " return step / warmup_steps\n", | |
| " return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))\n", | |
| " \n", | |
| " scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n", | |
| " \n", | |
| " predictor.train()\n", | |
| " best_loss = float('inf')\n", | |
| " \n", | |
| " for epoch in range(epochs):\n", | |
| " total_loss = 0\n", | |
| " num_batches = 0\n", | |
| " \n", | |
| " pbar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{epochs}\")\n", | |
| " for batch in pbar:\n", | |
| " with torch.no_grad():\n", | |
| " target_embeddings = encoder.encode_batch(batch[\"code\"]).to(device)\n", | |
| " \n", | |
| " predicted_embeddings = predictor(batch[\"description\"])\n", | |
| " loss = info_nce_loss(predicted_embeddings, target_embeddings)\n", | |
| " \n", | |
| " optimizer.zero_grad()\n", | |
| " loss.backward()\n", | |
| " torch.nn.utils.clip_grad_norm_(predictor.projector.parameters(), 1.0)\n", | |
| " optimizer.step()\n", | |
| " scheduler.step()\n", | |
| " \n", | |
| " total_loss += loss.item()\n", | |
| " num_batches += 1\n", | |
| " pbar.set_postfix({\"loss\": f\"{loss.item():.4f}\", \"lr\": f\"{scheduler.get_last_lr()[0]:.2e}\"})\n", | |
| " \n", | |
| " avg_loss = total_loss / num_batches\n", | |
| " print(f\"Epoch {epoch+1}: avg_loss = {avg_loss:.4f}\")\n", | |
| " \n", | |
| " if avg_loss < best_loss:\n", | |
| " best_loss = avg_loss\n", | |
| " \n", | |
| " predictor.eval()\n", | |
| " print(f\"\\nTraining complete! Best loss: {best_loss:.4f}\")\n", | |
| " return predictor" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Train the concept predictor\n", | |
| "print(\"Training concept predictor...\")\n", | |
| "print(\"=\" * 50)\n", | |
| "\n", | |
| "concept_predictor = train_concept_predictor(\n", | |
| " concept_predictor,\n", | |
| " concept_encoder,\n", | |
| " all_descriptions,\n", | |
| " all_codes,\n", | |
| " epochs=15,\n", | |
| " batch_size=32\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Test the trained predictor\n", | |
| "print(\"=\" * 60)\n", | |
| "print(\"Testing trained concept predictor\")\n", | |
| "print(\"=\" * 60)\n", | |
| "\n", | |
| "test_queries = [\n", | |
| " \"write a function to compute fibonacci numbers\",\n", | |
| " \"implement binary search algorithm\",\n", | |
| " \"check if a string is a valid palindrome\",\n", | |
| " \"find the maximum element in a list\",\n", | |
| " \"parse a JSON file and extract data\",\n", | |
| " \"implement an LRU cache\"\n", | |
| "]\n", | |
| "\n", | |
| "for query in test_queries:\n", | |
| " print(f\"\\nQuery: '{query}'\")\n", | |
| " print(\"-\" * 40)\n", | |
| " \n", | |
| " predicted_embedding = concept_predictor.predict(query)\n", | |
| " results = concept_bank.search(predicted_embedding, k=2)\n", | |
| " \n", | |
| " for i, r in enumerate(results):\n", | |
| " desc_preview = r['description'][:50].replace('\\n', ' ') if r['description'] else \"(no desc)\"\n", | |
| " print(f\" [{i+1}] (sim={r['similarity']:.3f}, src={r['source']}) {desc_preview}...\")" | |
| ] | |
| }, | |
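| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a rough quantitative check, the cell below measures in-sample recall@5: for a sample of bank descriptions, does the predicted concept rank the description's own code in the top 5? The number is optimistic (these pairs were seen in training) but makes a useful smoke test." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# In-sample recall@5 smoke test for the concept predictor.\n", | |
| "import random\n", | |
| "random.seed(0)\n", | |
| "\n", | |
| "sample_ids = random.sample(range(len(concept_bank.codes)), k=min(200, len(concept_bank.codes)))\n", | |
| "eval_ids = [i for i in sample_ids if concept_bank.descriptions[i]]\n", | |
| "\n", | |
| "hits = 0\n", | |
| "for idx in tqdm(eval_ids, desc=\"recall@5\", leave=False):\n", | |
| "    pred = concept_predictor.predict(concept_bank.descriptions[idx]).cpu()\n", | |
| "    topk = (concept_bank.embeddings @ pred).topk(5).indices.tolist()\n", | |
| "    hits += int(idx in topk)\n", | |
| "\n", | |
| "print(f\"recall@5 (in-sample, {len(eval_ids)} queries): {hits / len(eval_ids):.2%}\")" | |
| ] | |
| }, | |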
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 5: Concept-Conditioned Code Generation\n", | |
| "\n", | |
| "Using **Qwen3-Coder-30B-A3B-Instruct** - latest Qwen3 Coder MoE model (Dec 2025):\n", | |
| "- 30B total params, 3B active (MoE)\n", | |
| "- Best open-source code model as of Jan 2026\n", | |
| "- Excellent instruction following" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class ConceptFirstCodeGenerator:\n", | |
| " \"\"\"\n", | |
| " Concept-First Code Generation Pipeline.\n", | |
| " \n", | |
| " 1. Predict concept embedding from query (JEPA-style)\n", | |
| " 2. Retrieve similar code examples from concept bank\n", | |
| " 3. Generate code conditioned on examples (Qwen3-Coder)\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " concept_predictor: ConceptPredictor,\n", | |
| " concept_bank: ConceptBank,\n", | |
| " llm_name: str = \"Qwen/Qwen3-Coder-30B-A3B-Instruct\",\n", | |
| " num_examples: int = 3,\n", | |
| " use_4bit: bool = True\n", | |
| " ):\n", | |
| " \"\"\"\n", | |
| " Initialize with latest Qwen3 Coder model (Jan 2026).\n", | |
| " \n", | |
| " LLM Options:\n", | |
| " - \"Qwen/Qwen3-Coder-30B-A3B-Instruct\" (30B MoE, 3B active, SOTA)\n", | |
| " - \"Qwen/Qwen2.5-Coder-7B-Instruct\" (7B, good fallback)\n", | |
| " - \"Qwen/Qwen2.5-Coder-32B-Instruct\" (32B, if you have A100)\n", | |
| " \"\"\"\n", | |
| " self.concept_predictor = concept_predictor\n", | |
| " self.concept_bank = concept_bank\n", | |
| " self.num_examples = num_examples\n", | |
| " \n", | |
| " print(f\"Loading LLM: {llm_name}\")\n", | |
| " \n", | |
| " if use_4bit:\n", | |
| " bnb_config = BitsAndBytesConfig(\n", | |
| " load_in_4bit=True,\n", | |
| " bnb_4bit_quant_type=\"nf4\",\n", | |
| " bnb_4bit_compute_dtype=torch.bfloat16,\n", | |
| " bnb_4bit_use_double_quant=True\n", | |
| " )\n", | |
| " self.llm = AutoModelForCausalLM.from_pretrained(\n", | |
| " llm_name,\n", | |
| " quantization_config=bnb_config,\n", | |
| " device_map=\"auto\",\n", | |
| " trust_remote_code=True,\n", | |
| " attn_implementation=\"flash_attention_2\" if torch.cuda.is_available() else None\n", | |
| " )\n", | |
| " else:\n", | |
| " self.llm = AutoModelForCausalLM.from_pretrained(\n", | |
| " llm_name,\n", | |
| " torch_dtype=torch.bfloat16,\n", | |
| " device_map=\"auto\",\n", | |
| " trust_remote_code=True\n", | |
| " )\n", | |
| " \n", | |
| " self.tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True)\n", | |
| " if self.tokenizer.pad_token is None:\n", | |
| " self.tokenizer.pad_token = self.tokenizer.eos_token\n", | |
| " \n", | |
| " print(f\"LLM loaded successfully\")\n", | |
| " \n", | |
| " def retrieve_examples(self, query: str) -> List[Dict]:\n", | |
| " \"\"\"Retrieve relevant code examples using predicted concept.\"\"\"\n", | |
| " concept_embedding = self.concept_predictor.predict(query)\n", | |
| " examples = self.concept_bank.search(concept_embedding, k=self.num_examples)\n", | |
| " return examples\n", | |
| " \n", | |
| " def build_prompt(self, query: str, examples: List[Dict]) -> str:\n", | |
| " \"\"\"Build few-shot prompt with retrieved examples.\"\"\"\n", | |
| " prompt_parts = [\n", | |
| " \"You are an expert Python programmer. Write clean, efficient, and well-documented code.\",\n", | |
| " \"\",\n", | |
| " \"Here are some similar code examples for reference:\"\n", | |
| " ]\n", | |
| " \n", | |
| " for i, ex in enumerate(examples):\n", | |
| " desc = ex['description'][:150] if ex['description'] else \"(utility function)\"\n", | |
| " code = ex['code'][:500]\n", | |
| " prompt_parts.append(f\"\\n### Example {i+1}: {desc}\")\n", | |
| " prompt_parts.append(f\"```python\\n{code}\\n```\")\n", | |
| " \n", | |
| " prompt_parts.extend([\n", | |
| " \"\",\n", | |
| " f\"### Task: {query}\",\n", | |
| " \"\",\n", | |
| " \"Write the Python code:\",\n", | |
| " \"```python\"\n", | |
| " ])\n", | |
| " \n", | |
| " return \"\\n\".join(prompt_parts)\n", | |
| " \n", | |
| " def generate(\n", | |
| " self, \n", | |
| " query: str, \n", | |
| " max_new_tokens: int = 512,\n", | |
| " temperature: float = 0.2,\n", | |
| " show_examples: bool = False\n", | |
| " ) -> Dict:\n", | |
| " \"\"\"\n", | |
| " Generate code using concept-first approach.\n", | |
| " \"\"\"\n", | |
| " examples = self.retrieve_examples(query)\n", | |
| " \n", | |
| " if show_examples:\n", | |
| " print(\"Retrieved concept-matched examples:\")\n", | |
| " for i, ex in enumerate(examples):\n", | |
| " desc = ex['description'][:40].replace('\\n', ' ') if ex['description'] else \"(no desc)\"\n", | |
| " print(f\" [{i+1}] (sim={ex['similarity']:.3f}, src={ex['source']}) {desc}...\")\n", | |
| " \n", | |
| " prompt = self.build_prompt(query, examples)\n", | |
| " \n", | |
| " messages = [{\"role\": \"user\", \"content\": prompt}]\n", | |
| " text = self.tokenizer.apply_chat_template(\n", | |
| " messages, tokenize=False, add_generation_prompt=True\n", | |
| " )\n", | |
| " \n", | |
| " inputs = self.tokenizer(text, return_tensors=\"pt\").to(self.llm.device)\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " outputs = self.llm.generate(\n", | |
| " **inputs,\n", | |
| " max_new_tokens=max_new_tokens,\n", | |
| " temperature=temperature,\n", | |
| " do_sample=temperature > 0,\n", | |
| " top_p=0.95,\n", | |
| " pad_token_id=self.tokenizer.pad_token_id,\n", | |
| " eos_token_id=self.tokenizer.eos_token_id\n", | |
| " )\n", | |
| " \n", | |
| " generated = self.tokenizer.decode(\n", | |
| " outputs[0][inputs[\"input_ids\"].shape[1]:], \n", | |
| " skip_special_tokens=True\n", | |
| " )\n", | |
| " \n", | |
| " if \"```\" in generated:\n", | |
| " code = generated.split(\"```\")[0].strip()\n", | |
| " else:\n", | |
| " code = generated.strip()\n", | |
| " \n", | |
| " return {\n", | |
| " \"code\": code,\n", | |
| " \"examples\": examples,\n", | |
| " \"concept_similarities\": [ex[\"similarity\"] for ex in examples]\n", | |
| " }" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Initialize the concept-first generator\n", | |
| "# Use Qwen2.5-Coder-7B for T4, Qwen3-Coder-30B-A3B for A100\n", | |
| "generator = ConceptFirstCodeGenerator(\n", | |
| " concept_predictor=concept_predictor,\n", | |
| " concept_bank=concept_bank,\n", | |
| " llm_name=\"Qwen/Qwen2.5-Coder-7B-Instruct\", # Use Qwen3-Coder-30B-A3B if you have A100\n", | |
| " num_examples=3\n", | |
| ")" | |
| ] | |
| }, | |
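| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Before running full generation, it helps to inspect what the LLM is actually conditioned on; the cell below prints the start of one concept-conditioned prompt." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Peek at the concept-conditioned prompt for one query.\n", | |
| "examples = generator.retrieve_examples(\"reverse a string\")\n", | |
| "prompt = generator.build_prompt(\"reverse a string\", examples)\n", | |
| "print(prompt[:800])" | |
| ] | |
| }, | |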
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Test generation!\n", | |
| "print(\"=\" * 60)\n", | |
| "print(\"CONCEPT-FIRST CODE GENERATION (Jan 2026)\")\n", | |
| "print(\"=\" * 60)\n", | |
| "\n", | |
| "test_queries = [\n", | |
| " \"write a function to compute the nth fibonacci number efficiently using memoization\",\n", | |
| " \"implement a function to check if a number is prime\",\n", | |
| " \"write a function to find all permutations of a string\",\n", | |
| " \"implement a binary search tree with insert, search, and delete methods\"\n", | |
| "]\n", | |
| "\n", | |
| "for query in test_queries:\n", | |
| " print(f\"\\n{'='*60}\")\n", | |
| " print(f\"Query: {query}\")\n", | |
| " print(\"=\"*60)\n", | |
| " \n", | |
| " result = generator.generate(query, show_examples=True)\n", | |
| " \n", | |
| " print(f\"\\nGenerated Code:\")\n", | |
| " print(\"-\" * 40)\n", | |
| " print(result[\"code\"][:800])\n", | |
| " print(\"-\" * 40)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 6: Comparison - Concept-First vs Direct Generation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def generate_direct(generator, query: str, max_new_tokens: int = 512) -> str:\n", | |
| " \"\"\"Generate code directly without concept guidance.\"\"\"\n", | |
| " prompt = f\"\"\"You are an expert Python programmer. Write clean, efficient, and well-documented code.\n", | |
| "\n", | |
| "### Task: {query}\n", | |
| "\n", | |
| "Write the Python code:\n", | |
| "```python\"\"\"\n", | |
| " \n", | |
| " messages = [{\"role\": \"user\", \"content\": prompt}]\n", | |
| " text = generator.tokenizer.apply_chat_template(\n", | |
| " messages, tokenize=False, add_generation_prompt=True\n", | |
| " )\n", | |
| " \n", | |
| " inputs = generator.tokenizer(text, return_tensors=\"pt\").to(generator.llm.device)\n", | |
| " \n", | |
| " with torch.no_grad():\n", | |
| " outputs = generator.llm.generate(\n", | |
| " **inputs,\n", | |
| " max_new_tokens=max_new_tokens,\n", | |
| " temperature=0.2,\n", | |
| " do_sample=True,\n", | |
| " top_p=0.95,\n", | |
| " pad_token_id=generator.tokenizer.pad_token_id\n", | |
| " )\n", | |
| " \n", | |
| " generated = generator.tokenizer.decode(\n", | |
| " outputs[0][inputs[\"input_ids\"].shape[1]:], \n", | |
| " skip_special_tokens=True\n", | |
| " )\n", | |
| " \n", | |
| " if \"```\" in generated:\n", | |
| " return generated.split(\"```\")[0].strip()\n", | |
| " return generated.strip()\n", | |
| "\n", | |
| "\n", | |
| "# Compare approaches\n", | |
| "print(\"=\" * 70)\n", | |
| "print(\"COMPARISON: Concept-First vs Direct Generation\")\n", | |
| "print(\"=\" * 70)\n", | |
| "\n", | |
| "comparison_queries = [\n", | |
| " \"implement merge sort algorithm\",\n", | |
| " \"write a function to validate an email address using regex\",\n", | |
| " \"implement an LRU cache with O(1) get and put operations\"\n", | |
| "]\n", | |
| "\n", | |
| "for query in comparison_queries:\n", | |
| " print(f\"\\n{'='*70}\")\n", | |
| " print(f\"Query: {query}\")\n", | |
| " print(\"=\"*70)\n", | |
| " \n", | |
| " print(\"\\n[CONCEPT-FIRST] - Uses predicted concept to retrieve examples\")\n", | |
| " result = generator.generate(query, show_examples=True)\n", | |
| " print(f\"\\nGenerated:\")\n", | |
| " print(result[\"code\"][:500])\n", | |
| " \n", | |
| " print(\"\\n\" + \"-\"*70)\n", | |
| " print(\"[DIRECT] - No concept guidance, just the query\")\n", | |
| " direct_code = generate_direct(generator, query)\n", | |
| " print(direct_code[:500])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 7: Concept Embedding Visualization" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "!pip install -q matplotlib scikit-learn umap-learn" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import matplotlib.pyplot as plt\n", | |
| "from sklearn.manifold import TSNE\n", | |
| "\n", | |
| "# Get embeddings for various queries grouped by concept\n", | |
| "query_categories = {\n", | |
| " \"recursion\": [\n", | |
| " \"compute fibonacci recursively\",\n", | |
| " \"calculate factorial\",\n", | |
| " \"recursive tree traversal\",\n", | |
| " \"recursive binary search\"\n", | |
| " ],\n", | |
| " \"sorting\": [\n", | |
| " \"implement quicksort\",\n", | |
| " \"bubble sort algorithm\",\n", | |
| " \"merge sort implementation\",\n", | |
| " \"heap sort\"\n", | |
| " ],\n", | |
| " \"strings\": [\n", | |
| " \"reverse a string\",\n", | |
| " \"check palindrome\",\n", | |
| " \"count character frequency\",\n", | |
| " \"find longest substring\"\n", | |
| " ],\n", | |
| " \"data_structures\": [\n", | |
| " \"implement linked list\",\n", | |
| " \"binary search tree\",\n", | |
| " \"hash table implementation\",\n", | |
| " \"stack using array\"\n", | |
| " ]\n", | |
| "}\n", | |
| "\n", | |
| "# Collect embeddings\n", | |
| "all_queries = []\n", | |
| "all_labels = []\n", | |
| "all_embeddings = []\n", | |
| "\n", | |
| "for category, queries in query_categories.items():\n", | |
| " for q in queries:\n", | |
| " all_queries.append(q)\n", | |
| " all_labels.append(category)\n", | |
| " emb = concept_predictor.predict(q).cpu().numpy()\n", | |
| " all_embeddings.append(emb)\n", | |
| "\n", | |
| "all_embeddings = np.stack(all_embeddings)\n", | |
| "\n", | |
| "# t-SNE visualization\n", | |
| "tsne = TSNE(n_components=2, random_state=42, perplexity=5)\n", | |
| "embeddings_2d = tsne.fit_transform(all_embeddings)\n", | |
| "\n", | |
| "# Plot\n", | |
| "plt.figure(figsize=(12, 10))\n", | |
| "colors = {\n", | |
| " \"recursion\": \"#e41a1c\", \n", | |
| " \"sorting\": \"#377eb8\", \n", | |
| " \"strings\": \"#4daf4a\", \n", | |
| " \"data_structures\": \"#984ea3\"\n", | |
| "}\n", | |
| "\n", | |
| "for i, (q, label) in enumerate(zip(all_queries, all_labels)):\n", | |
| " plt.scatter(embeddings_2d[i, 0], embeddings_2d[i, 1], \n", | |
| " c=colors[label], s=150, alpha=0.7, edgecolors='white', linewidth=1)\n", | |
| " plt.annotate(q[:20], (embeddings_2d[i, 0]+0.5, embeddings_2d[i, 1]+0.5), fontsize=8)\n", | |
| "\n", | |
| "for label, color in colors.items():\n", | |
| " plt.scatter([], [], c=color, label=label.replace('_', ' ').title(), s=150)\n", | |
| "plt.legend(loc='upper right', fontsize=10)\n", | |
| "\n", | |
| "plt.title(\"Concept Space Visualization (t-SNE)\\nSimilar coding concepts cluster together\", fontsize=14)\n", | |
| "plt.xlabel(\"Dimension 1\", fontsize=12)\n", | |
| "plt.ylabel(\"Dimension 2\", fontsize=12)\n", | |
| "plt.tight_layout()\n", | |
| "plt.savefig(\"concept_space.png\", dpi=150)\n", | |
| "plt.show()\n", | |
| "\n", | |
| "print(\"\\nSaved visualization to concept_space.png\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Part 8: Save Models" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Save concept predictor\n", | |
| "torch.save({\n", | |
| " \"projector_state_dict\": concept_predictor.projector.state_dict(),\n", | |
| " \"concept_dim\": concept_predictor.concept_dim,\n", | |
| " \"text_dim\": concept_predictor.text_dim,\n", | |
| "}, \"concept_predictor.pt\")\n", | |
| "\n", | |
| "# Save concept bank embeddings\n", | |
| "torch.save({\n", | |
| " \"embeddings\": concept_bank.embeddings,\n", | |
| " \"codes\": concept_bank.codes,\n", | |
| " \"descriptions\": concept_bank.descriptions,\n", | |
| " \"sources\": concept_bank.sources\n", | |
| "}, \"concept_bank.pt\")\n", | |
| "\n", | |
| "print(\"Models saved!\")\n", | |
| "print(\" - concept_predictor.pt\")\n", | |
| "print(\" - concept_bank.pt\")\n", | |
| "print(f\"\\nConcept bank: {len(concept_bank.codes)} concepts\")\n", | |
| "print(f\"Embedding dim: {concept_bank.embeddings.shape[1]}\")" | |
| ] | |
| }, | |
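| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A sketch of restoring these artifacts in a fresh session. It assumes the same encoder checkpoints are available (`ConceptPredictor` reloads its frozen text encoder on construction, and the bank reuses `concept_encoder`)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: reload the saved predictor head and concept bank.\n", | |
| "ckpt = torch.load(\"concept_predictor.pt\", map_location=\"cpu\")\n", | |
| "restored_predictor = ConceptPredictor(concept_dim=ckpt[\"concept_dim\"])\n", | |
| "restored_predictor.projector.load_state_dict(ckpt[\"projector_state_dict\"])\n", | |
| "restored_predictor.eval()\n", | |
| "\n", | |
| "bank_data = torch.load(\"concept_bank.pt\", map_location=\"cpu\")\n", | |
| "restored_bank = ConceptBank(concept_encoder)\n", | |
| "restored_bank.embeddings = bank_data[\"embeddings\"]\n", | |
| "restored_bank.codes = bank_data[\"codes\"]\n", | |
| "restored_bank.descriptions = bank_data[\"descriptions\"]\n", | |
| "restored_bank.sources = bank_data[\"sources\"]\n", | |
| "print(f\"Restored {len(restored_bank.codes)} concepts\")" | |
| ] | |
| }, | |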
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Summary\n", | |
| "\n", | |
| "We implemented a **Concept-First Code Generation** pipeline inspired by VL-JEPA:\n", | |
| "\n", | |
| "### Models Used (January 2026 - Latest)\n", | |
| "\n", | |
| "| Component | Model | Size | Notes |\n", | |
| "|-----------|-------|------|-------|\n", | |
| "| Code Encoder | `Salesforce/SFR-Embedding-Code-2B_R` | 2B | SOTA CoIR 67.4 |\n", | |
| "| Text Encoder | `Alibaba-NLP/gte-Qwen2-1.5B-instruct` | 1.5B | Latest GTE |\n", | |
| "| Code LLM | `Qwen/Qwen3-Coder-30B-A3B-Instruct` | 30B (3B active) | Latest MoE coder |\n", | |
| "\n", | |
| "### Datasets (2026)\n", | |
| "- MBPP (google-research-datasets)\n", | |
| "- HumanEval (OpenAI)\n", | |
| "- CodeSearchNet Python\n", | |
| "- Evol-Instruct-Code-80k\n", | |
| "- Magicoder-OSS-Instruct-75K\n", | |
| "\n", | |
| "### Key Insights\n", | |
| "\n", | |
| "1. **Concept prediction** helps form the \"bigger picture\" before generating tokens\n", | |
| "2. **SFR-Embedding-Code** provides SOTA code embeddings for retrieval\n", | |
| "3. **Semantic clustering** in concept space groups related patterns\n", | |
| "4. **Qwen3-Coder MoE** gives best generation quality with efficient inference\n", | |
| "\n", | |
| "### Next Steps\n", | |
| "\n", | |
| "1. **Hierarchical concepts**: Program -> Function -> Block level\n", | |
| "2. **Integration with RLM**: Use concepts as working memory\n", | |
| "3. **Distillation to Distillix**: Train small BitNet to predict concepts\n", | |
| "4. **Evaluation**: HumanEval, MBPP pass@1 comparison" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.11.0" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |