@agcom
Last active November 2, 2025 13:47
Creative stratified split of a dataset using scikit-learn's `train_test_split` function recursively
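import random
from random import Random
from typing import Any

from sklearn.model_selection import train_test_split


def stratified_split(*arrays, labels, n_splits: int, rnd: Random | None = None) -> tuple[Any, ...]:
    """Split every array into `n_splits` stratified parts of roughly equal size.

    The parts are returned grouped by split: for arrays `a` and `b` the result
    is `(a_1, b_1, a_2, b_2, ..., a_n, b_n)`. `labels` is only used for
    stratification; pass it among `arrays` as well to get its parts back.
    """
    if n_splits == 1:
        # Base case: whatever is left forms the last part.
        return arrays
    else:
        splits = train_test_split(
            *arrays, labels,
            # Carve 1/n_splits off the remaining data; the rest is split recursively.
            train_size=1 / n_splits,
            # Seed range (0 to 2**32 - 1) taken from the error message scikit-learn
            # raises when given an invalid random_state.
            random_state=rnd.randint(0, 4294967295) if rnd is not None else random.randint(0, 4294967295),
            stratify=labels
        )
        # The last two items are the label halves; only the remainder is needed.
        _, labels_rest = splits[-2:]
        arrays_splits = splits[:-2]
        # Even positions hold this part of each array, odd positions the remainder.
        return *arrays_splits[::2], *stratified_split(*arrays_splits[1::2], labels=labels_rest, n_splits=n_splits - 1, rnd=rnd)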
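
The function above passes `train_size=1 / n_splits` at each level of recursion, with `n_splits` shrinking by one each time. The following is a minimal sketch of why that yields parts of equal proportion. It is pure arithmetic with exact fractions and does not call scikit-learn; `part_fractions` is a hypothetical helper name introduced only for illustration.

```python
from fractions import Fraction


def part_fractions(n_splits: int) -> list[Fraction]:
    """Share of the original data that each part receives (hypothetical helper)."""
    remaining = Fraction(1)  # share of the data not yet assigned to a part
    parts = []
    # Mirror the recursion: n_splits, n_splits - 1, ..., 2.
    for k in range(n_splits, 1, -1):
        share = remaining * Fraction(1, k)  # train_size = 1/k of what is left
        parts.append(share)
        remaining -= share
    parts.append(remaining)  # base case (n_splits == 1): the leftover is the last part
    return parts


print(part_fractions(4))
# [Fraction(1, 4), Fraction(1, 4), Fraction(1, 4), Fraction(1, 4)]
```

Taking 1/4 of the whole, then 1/3 of the remaining 3/4, then 1/2 of the remaining 1/2 gives a quarter each time. On real data the parts will presumably be only approximately equal, since each split has to be rounded to a whole number of samples.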