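"""
Stream repo-grouped .txt files (one merged file per repository, with optional
.ids/.paths sidecars) into a Hugging Face dataset: examples are batched into
splits of BATCH_SIZE, pushed to the Hub one shard at a time, and the last
completed shard index is recorded in shard_log.txt so an interrupted run can
resume where it left off.
"""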
import gc
import multiprocessing as mp
import os
import re
import time

import pandas as pd
import psutil
from datasets import Dataset, load_from_disk
from tqdm import tqdm
DATA_DIR = "grouped_by_repo"
OUTPUT_DIR = "hf_dataset_shards_optimized_new"
DATASET_NAME = "yadapruk/grouped-starcoderdata-merged"
BATCH_SIZE = 10000

os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Detect existing shards ---
existing_shards = sorted([
    int(re.search(r"shard_(\d+)", f).group(1))
    for f in os.listdir(OUTPUT_DIR)
    if f.startswith("shard_") and os.path.isdir(os.path.join(OUTPUT_DIR, f))
])

shard_log_path = os.path.join(OUTPUT_DIR, "shard_log.txt")
if os.path.exists(shard_log_path):
    with open(shard_log_path, "r", encoding="utf-8") as shard_log_f:
        starting_shard_idx = int(shard_log_f.read().strip())
else:
    starting_shard_idx = 0  # first run: no shard log yet
already_processed_files = starting_shard_idx * BATCH_SIZE

print(f"[INFO] Resuming from shard {starting_shard_idx}")
print(f"[INFO] Skipping first {already_processed_files} examples")
def print_mem_usage():
    process = psutil.Process(os.getpid())
    mem_mb = process.memory_info().rss / 1e6  # resident set size in MB
    print(f"[MEM] Memory usage: {mem_mb:.2f} MB")
# --- File processing function ---
def process_file(fname):
    """Read one repo's merged .txt file plus its .ids/.paths sidecars into a record."""
    if not fname.endswith(".txt"):
        return None
    repo_name = fname[:-4].replace("__", "/")
    txt_path = os.path.join(DATA_DIR, fname)
    ids_path = os.path.join(DATA_DIR, fname.replace(".txt", ".ids"))
    file_paths_path = os.path.join(DATA_DIR, fname.replace(".txt", ".paths"))
    try:
        with open(txt_path, encoding="utf-8") as f:
            content = f.read()

        source_ids = []
        if os.path.exists(ids_path):
            with open(ids_path, encoding="utf-8") as f_ids:
                source_ids = list(set(f_ids.read().splitlines()))

        file_paths = []
        if os.path.exists(file_paths_path):
            with open(file_paths_path, encoding="utf-8") as f_paths:
                file_paths = list(set(f_paths.read().splitlines()))

        return {
            "max_stars_repo_name": repo_name,
            "content": content,
            "source_ids": source_ids,
            "max_stars_repo_paths": file_paths,
        }
    except Exception as e:
        error_log_path = os.path.join(OUTPUT_DIR, "error_log.txt")
        with open(error_log_path, "a", encoding="utf-8") as log_f:
            log_f.write(f"ERROR processing {fname}: {e}\n")
        print(f"[ERROR] Skipping {fname}: {e}")
        return None  # skip the bad file instead of crashing the whole pool
# --- Writer (in main process, memory-safe) ---
def push_with_retry(dataset, dataset_name, split_name, token=None, max_retries=5):
    """Push a dataset split to the Hub, retrying with exponential backoff on rate limits."""
    for i in range(max_retries):
        try:
            dataset.push_to_hub(dataset_name, split=split_name, token=token)
            return
        except Exception as e:
            print(e)
            if "429" in str(e) or "RateLimitExceeded" in str(e) or "500" in str(e):
                wait = 2 ** i
                if i == (max_retries - 1):
                    wait = 50 * 60  # last attempt: back off for 50 minutes
                print(f"[WARN] Rate limit hit, retrying in {wait}s (attempt {i+1}/{max_retries})...")
                time.sleep(wait)
            else:
                raise
    raise RuntimeError(f"Failed to push {split_name} after {max_retries} retries due to rate limit.")
def stream_and_push_to_hub(
    examples_iterator,
    dataset_name,  # e.g. "yadapruk/grouped-starcoderdata-merged"
    batch_size,
    starting_shard_idx,
    hf_token=None,  # optional: HF token for auth
):
    file_buffer = []
    shard_idx = starting_shard_idx
    for example in tqdm(examples_iterator):
        if example is None:
            continue
        file_buffer.append(example)
        if len(file_buffer) >= batch_size:
            print("start of push")
            print_mem_usage()
            dataset = Dataset.from_list(file_buffer)
            print("loaded dataset")
            print_mem_usage()
            split_name = f"shard_{shard_idx}"
            print(f"[INFO] Pushing {len(file_buffer)} examples to split='{split_name}'...")
            push_with_retry(dataset, dataset_name, split_name, token=hf_token)
            print("after push")
            print_mem_usage()
            del dataset
            gc.collect()
            print("after collection")
            print_mem_usage()
            file_buffer = []
            shard_idx += 1
            # Save shard index in logger so a restart can resume from here
            shard_log_path = os.path.join(OUTPUT_DIR, "shard_log.txt")
            with open(shard_log_path, "w", encoding="utf-8") as shard_log_f:
                shard_log_f.write(f"{shard_idx}\n")
    # Push final incomplete batch
    if file_buffer:
        df = pd.DataFrame(file_buffer)
        dataset = Dataset.from_pandas(df)
        split_name = f"shard_{shard_idx}"
        print(f"[INFO] Pushing final {len(file_buffer)} examples to split='{split_name}'...")
        push_with_retry(dataset, dataset_name, split_name, token=hf_token)
    return shard_idx
def push_shard(shard_idx):
    """Push a shard previously saved to disk (helper; not called in main())."""
    try:
        shard_path = os.path.join(OUTPUT_DIR, f"shard_{shard_idx}")
        print(f"[INFO] Pushing shard_{shard_idx}")
        ds = load_from_disk(shard_path)
        ds.push_to_hub(DATASET_NAME, private=True)
    except Exception as e:
        error_log_path = os.path.join(OUTPUT_DIR, "error_log.txt")
        with open(error_log_path, "a", encoding="utf-8") as log_f:
            log_f.write(f"ERROR pushing shard_{shard_idx} to hub: {e}\n")
        return None
def main():
    # Only .txt files produce examples, so filter them up front; this keeps the
    # resume offset (shards * BATCH_SIZE) aligned with the file listing.
    all_files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".txt"))
    files_to_process = all_files[already_processed_files:]
    with mp.Pool(max(1, mp.cpu_count() - 1)) as pool:
        iterator = pool.imap(process_file, files_to_process, chunksize=1)
        # Single-process fallback, useful for debugging:
        # def iterate(files_to_process):
        #     for fname in files_to_process:
        #         yield process_file(fname)
        # iterator = iterate(files_to_process)
        stream_and_push_to_hub(iterator, DATASET_NAME, BATCH_SIZE, starting_shard_idx)
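# Hub authentication is assumed to be configured outside this script: when no
# token is passed, push_to_hub() falls back to the cached Hugging Face
# credentials (e.g. from a prior `huggingface-cli login`).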
if __name__ == "__main__":
    main()