Created August 11, 2025 03:23
label_programming_language.py
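"""Label each sample in the jxm/gpt-oss20b-samples dataset with a programming
language using a sequence-classification model, keep only predictions made with
probability above 0.8, and save the aggregated per-language counts to JSON."""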
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from collections import Counter
import json
import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
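# Class index -> language name for the model's 26 output labels.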
PROGRAMMING_LANGUAGE_MAP = {
    0: "ARM Assembly",
    1: "AppleScript",
    2: "C",
    3: "C#",
    4: "C++",
    5: "COBOL",
    6: "Erlang",
    7: "Fortran",
    8: "Go",
    9: "Java",
    10: "JavaScript",
    11: "Kotlin",
    12: "Lua",
    13: "Mathematica/Wolfram Language",
    14: "PHP",
    15: "Pascal",
    16: "Perl",
    17: "PowerShell",
    18: "Python",
    19: "R",
    20: "Ruby",
    21: "Rust",
    22: "Scala",
    23: "Swift",
    24: "Visual Basic .NET",
    25: "jq",
}
def load_programming_language_classifier():
    """Load the programming language classification model and tokenizer."""
    model_name = "philomath-1209/programming-language-identification"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()  # disable dropout for deterministic inference
    return model, tokenizer
def tokenize(text, tokenizer, max_tokens=64) -> dict[str, torch.Tensor]:
    """Tokenize a batch of texts, truncating each to at most `max_tokens` tokens."""
    # Pre-truncate at the character level so the tokenizer never processes very long strings.
    text = [t[:2000] for t in text]
    tokenized = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_tokens,
        return_tensors="pt",
    )
    return tokenized
def main():
    """Classify every text in the dataset and save per-language counts to JSON."""
    print("Loading programming language classifier...")
    model, tokenizer = load_programming_language_classifier()

    print("Loading dataset...")
    # Login using e.g. `huggingface-cli login` to access this dataset
    ds = load_dataset("jxm/gpt-oss20b-samples")
    train_data = ds["train"]
    print(f"Found {len(train_data)} samples to classify...")

    # Counter to track programming language classifications
    language_counter = Counter()

    # Process in large batches
    batch_size = 8192 * 2
    total_batches = (len(train_data) + batch_size - 1) // batch_size
    print(f"Processing in {total_batches} batches of {batch_size}...")

    for batch_idx in tqdm(range(total_batches), desc="Processing batches"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(train_data))
        batch = tokenize(train_data[start_idx:end_idx]["text"], tokenizer).to(device)
        with torch.no_grad():
            outputs = model(**batch)
        # Keep only predictions whose top-class probability exceeds 0.8.
        pmax = torch.softmax(outputs.logits, dim=-1).max(dim=-1)
        probabilities = pmax.values
        predictions = pmax.indices
        pmask = probabilities > 0.8
        predictions = predictions[pmask].cpu().tolist()
        for p in predictions:
            language_counter[PROGRAMMING_LANGUAGE_MAP[p]] += 1

    output_file = "label_programming_language.json"
    with open(output_file, "w") as f:
        json.dump(dict(language_counter), f, indent=2)

    print(f"Classification complete! Results saved to {output_file}")
    print(f"Total programming languages detected: {len(language_counter)}")
    print("Top 10 programming languages:")
    for lang, count in language_counter.most_common(10):
        print(f"  {lang}: {count}")

if __name__ == "__main__":
    main()
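A minimal sketch for inspecting the results afterward (assuming the script above has been run and `label_programming_language.json` exists in the working directory):

import json

with open("label_programming_language.json") as f:
    counts = json.load(f)

# Sort languages by detected sample count, descending, and print them all.
for lang, count in sorted(counts.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{lang}: {count}")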