import gc
import multiprocessing as mp
import os
import re
import time

import pandas as pd
import psutil
from datasets import Dataset, load_from_disk
from tqdm import tqdm

DATA_DIR = "grouped_by_repo"
OUTPUT_DIR = "hf_dataset_shards_optimized_new"
DATASET_NAME = "yadapruk/grouped-starcoderdata-merged"
BATCH_SIZE = 10000

os.makedirs(OUTPUT_DIR, exist_ok=True)
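
# Assumed layout of DATA_DIR (a sketch inferred from process_file below; the
# grouping step that produces these files is not part of this gist):
#
#   grouped_by_repo/
#       owner__repo.txt    # concatenated file contents for one repository
#       owner__repo.ids    # one source id per line (optional)
#       owner__repo.paths  # one original file path per line (optional)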

# --- Detect existing shards on disk (informational; resume is driven by shard_log.txt) ---
existing_shards = sorted(
    int(re.search(r"shard_(\d+)", f).group(1))
    for f in os.listdir(OUTPUT_DIR)
    if f.startswith("shard_") and os.path.isdir(os.path.join(OUTPUT_DIR, f))
)

# shard_log.txt holds the index of the next shard to push; start from 0 if it is absent.
shard_log_path = os.path.join(OUTPUT_DIR, "shard_log.txt")
if os.path.exists(shard_log_path):
    with open(shard_log_path, "r", encoding="utf-8") as shard_log_f:
        starting_shard_idx = int(shard_log_f.read().strip())
else:
    starting_shard_idx = 0
already_processed_files = starting_shard_idx * BATCH_SIZE

print(f"[INFO] Resuming from shard {starting_shard_idx}")
print(f"[INFO] Skipping first {already_processed_files} examples")

def print_mem_usage():
    process = psutil.Process(os.getpid())
    mem_mb = process.memory_info().rss / 1e6  # Resident Set Size in MB
    print(f"[MEM] Memory usage: {mem_mb:.2f} MB")

# --- File processing function (runs in worker processes) ---
def process_file(fname):
    if not fname.endswith(".txt"):
        return None
    repo_name = fname[:-4].replace("__", "/")
    txt_path = os.path.join(DATA_DIR, fname)
    ids_path = os.path.join(DATA_DIR, fname.replace(".txt", ".ids"))
    file_paths_path = os.path.join(DATA_DIR, fname.replace(".txt", ".paths"))
    try:
        with open(txt_path, encoding="utf-8") as f:
            content = f.read()
        source_ids = []
        if os.path.exists(ids_path):
            with open(ids_path, encoding="utf-8") as f_ids:
                source_ids = list(set(f_ids.read().splitlines()))
        file_paths = []
        if os.path.exists(file_paths_path):
            with open(file_paths_path, encoding="utf-8") as f_paths:
                file_paths = list(set(f_paths.read().splitlines()))
        return {
            "max_stars_repo_name": repo_name,
            "content": content,
            "source_ids": source_ids,
            "max_stars_repo_paths": file_paths,
        }
    except Exception as e:
        # Log the failure and skip this file; the writer drops None examples.
        error_log_path = os.path.join(OUTPUT_DIR, "error_log.txt")
        with open(error_log_path, "a", encoding="utf-8") as log_f:
            log_f.write(f"ERROR processing {fname}: {e}\n")
        print(f"[ERROR] Skipping {fname}: {e}")
        return None

# --- Writer (in main process, memory-safe) ---
def push_with_retry(dataset, dataset_name, split_name, token=None, max_retries=5):
    """Push a dataset split to the Hub, backing off on rate-limit / server errors."""
    for i in range(max_retries):
        try:
            dataset.push_to_hub(dataset_name, split=split_name, token=token)
            return
        except Exception as e:
            print(e)
            if "429" in str(e) or "RateLimitExceeded" in str(e) or "500" in str(e):
                # Exponential backoff: 1s, 2s, 4s, 8s, then a long 50-minute wait
                # on the final attempt.
                wait = 2 ** i
                if i == (max_retries - 1):
                    wait = 50 * 60
                print(f"[WARN] Rate limit hit, retrying in {wait}s (attempt {i+1}/{max_retries})...")
                time.sleep(wait)
            else:
                raise e
    raise RuntimeError(f"Failed to push {split_name} after {max_retries} retries due to rate limit.")

def stream_and_push_to_hub(
    examples_iterator,
    dataset_name,        # e.g. "yadapruk/grouped-starcoderdata-merged"
    batch_size,
    starting_shard_idx,
    hf_token=None,       # Optional: HF token for auth
):
    file_buffer = []
    shard_idx = starting_shard_idx
    for example in tqdm(examples_iterator):
        if example is None:
            continue
        file_buffer.append(example)
        if len(file_buffer) >= batch_size:
            print("start of push")
            print_mem_usage()
            dataset = Dataset.from_list(file_buffer)
            print("loaded dataset")
            print_mem_usage()
            split_name = f"shard_{shard_idx}"
            print(f"[INFO] Pushing {len(file_buffer)} examples to split='{split_name}'...")
            push_with_retry(dataset, dataset_name, split_name, token=hf_token)
            print("after push")
            print_mem_usage()
            del dataset
            gc.collect()
            print("after collection")
            print_mem_usage()
            file_buffer = []
            shard_idx += 1
            # Save shard index in logger
            shard_log_path = os.path.join(OUTPUT_DIR, "shard_log.txt")
            with open(shard_log_path, "w", encoding="utf-8") as shard_log_f:
                shard_log_f.write(f"{shard_idx}\n")
    # Push the final, incomplete batch with the same retry logic
    if file_buffer:
        df = pd.DataFrame(file_buffer)
        dataset = Dataset.from_pandas(df)
        split_name = f"shard_{shard_idx}"
        print(f"[INFO] Pushing final {len(file_buffer)} examples to split='{split_name}'...")
        push_with_retry(dataset, dataset_name, split_name, token=hf_token)
    return shard_idx
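
# Read-back sketch (an assumption, not part of the original script): each call to
# push_with_retry above creates or updates a split named "shard_<n>" in the Hub
# repo, so a single shard can be spot-checked roughly like this:
#
#   from datasets import load_dataset
#   sample = load_dataset(DATASET_NAME, split="shard_0")
#   print(len(sample), sample[0]["max_stars_repo_name"])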

def push_shard(shard_idx):
    """Push a shard previously saved to disk back to the Hub (helper; not called by main)."""
    try:
        shard_path = os.path.join(OUTPUT_DIR, f"shard_{shard_idx}")
        print(f"[INFO] Pushing shard_{shard_idx}")
        ds = load_from_disk(shard_path)
        ds.push_to_hub(DATASET_NAME, private=True)
    except Exception as e:
        error_log_path = os.path.join(OUTPUT_DIR, "error_log.txt")
        with open(error_log_path, "a", encoding="utf-8") as log_f:
            log_f.write(f"ERROR pushing shard_{shard_idx} to hub: {e}\n")
        return None

def main():
    all_files = sorted(os.listdir(DATA_DIR))
    # NOTE: this skip count assumes each entry in DATA_DIR yields one example
    # (i.e. the directory holds only .txt files, one per repo).
    files_to_process = all_files[already_processed_files:]
    with mp.Pool(mp.cpu_count() - 1) as pool:
        iterator = pool.imap(process_file, files_to_process, chunksize=1)
        """
        # Single-process alternative:
        def iterate(files_to_process):
            for fname in files_to_process:
                yield process_file(fname)
        iterator = iterate(files_to_process)
        """
        stream_and_push_to_hub(iterator, DATASET_NAME, BATCH_SIZE, starting_shard_idx)


if __name__ == "__main__":
    main()
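
# Usage sketch (assumptions, not part of the original gist): no token is passed to
# stream_and_push_to_hub, so pushing relies on cached Hugging Face credentials, and
# the pushed splits can later be stitched back together, e.g.:
#
#   huggingface-cli login            # once, so push_to_hub(token=None) can authenticate
#   python stream_push_shards.py     # hypothetical file name for this gist
#
#   from datasets import load_dataset, concatenate_datasets, get_dataset_split_names
#   splits = [s for s in get_dataset_split_names(DATASET_NAME) if s.startswith("shard_")]
#   full = concatenate_datasets([load_dataset(DATASET_NAME, split=s) for s in splits])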