{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "670eee2e-84d1-4a74-a0ae-b442bba929a1",
"metadata": {},
"outputs": [],
"source": [
"import dask.array as da\n",
"import numpy as np\n",
"import xarray as xr\n",
"import time\n",
"import zarr\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "088cce47-956a-43e7-9bc0-b326ad5e828b",
"metadata": {},
"outputs": [],
"source": [
"# Start by creating some sample data we can load from\n",
"\n",
"# We'll save this many zarr files\n",
"num_zarr_files = 2\n",
"\n",
"# They will be assigned file names like this\n",
"zarr_file_pattern = \"data_partition_{}.zarr\"\n",
"\n",
"# Choose a compressor - only this one and None have been tested\n",
"compressor = zarr.Blosc(cname=\"zstd\", clevel=3, shuffle=2)\n",
"\n",
"# We'll make an array with these dimensions\n",
"dims = [\"time\", \"x\", \"y\"]\n",
"\n",
"# We'll use these chunk sizes\n",
"chunk_sizes = {\n",
"    \"time\": 10,\n",
"    \"x\": 30,\n",
"    \"y\": 30,\n",
"}\n",
"\n",
"# And this number of chunks along each dimension\n",
"number_of_chunks = {\n",
"    \"time\": 100,\n",
"    \"x\": 5,\n",
"    \"y\": 5,\n",
"}\n",
"\n",
"\n",
"# Create an array of random data. We'll save the same underlying data in each zarr but with different coords\n",
"# This array is about 170MB\n",
"dask_array = da.random.random(\n",
"    size=tuple(chunk_sizes[dim]*number_of_chunks[dim] for dim in dims),\n",
"    chunks=tuple(chunk_sizes[dim] for dim in dims)\n",
")\n",
"\n",
"# Convert to xarray Dataset\n",
"coords = {dim: np.arange(n) for n, dim in zip(dask_array.shape, dims)}\n",
"\n",
"ds = xr.DataArray(\n",
"    dask_array,\n",
"    dims=dims,\n",
"    coords=coords,\n",
").to_dataset(name=\"data\")\n",
"\n",
"\n",
"# Loop through and save copies of the array\n",
"zarr_paths = []\n",
"\n",
"for i in range(num_zarr_files):\n",
"\n",
"    # Shift the time coords so each partition covers a distinct period and the zarrs can be concatenated later\n",
"    ds_partition = ds.copy(deep=True)\n",
"    ds_partition[\"time\"] = ds[\"time\"] + i * len(coords[\"time\"])\n",
"\n",
"    save_path = zarr_file_pattern.format(i)\n",
"\n",
"    ds_partition.to_zarr(save_path, encoding={\"data\": {\"compressor\": compressor}})\n",
"\n",
"    zarr_paths.append(save_path)"
]
},
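{
"cell_type": "markdown",
"id": "b1f3c7aa-9d2e-4f1b-8a6c-5e0d4c3b2a10",
"metadata": {},
"source": [
"As a quick sanity check, the sketch below (not executed here, hence no output) shows one way to confirm the chunk layout actually written to disk. It only assumes the `zarr_paths` list built above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2e4d8bb-0e3f-4a2c-9b7d-6f1e5d4c3b21",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (not executed here): open a single store lazily and inspect its on-disk chunks\n",
"# xr.open_zarr keeps the chunking the data was written with\n",
"ds_check = xr.open_zarr(zarr_paths[0])\n",
"print(ds_check.chunks)"
]
},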
{
"cell_type": "code",
"execution_count": 3,
"id": "840489b0-a555-452e-a30e-04ac316a01ae",
"metadata": {},
"outputs": [],
"source": [
"def load_samples(ds, num_samples=100, t_size=6, x_size=10, y_size=10, seed=0):\n",
"    \"\"\"Load a small slice of the data from the input dataset multiple times\n",
"\n",
"    Args:\n",
"        ds: The xarray dataset\n",
"        num_samples: The number of samples to load\n",
"        t_size: The size of the sample window in the time dimension\n",
"        x_size: The size of the sample window in the x dimension\n",
"        y_size: The size of the sample window in the y dimension\n",
"        seed: The random seed\n",
"    \"\"\"\n",
"\n",
"    # Control for the random seed\n",
"    np.random.seed(seed=seed)\n",
"\n",
"    # Sample many times\n",
"    for i in range(num_samples):\n",
"\n",
"        # Randomly select the starting index for this sample across all dimensions\n",
"        t_i = np.random.randint(0, len(ds[\"time\"])-t_size)\n",
"        x_i = np.random.randint(0, len(ds[\"x\"])-x_size)\n",
"        y_i = np.random.randint(0, len(ds[\"y\"])-y_size)\n",
"\n",
"        # Load the sample\n",
"        ds_sel = ds.isel(\n",
"            {\n",
"                \"time\": slice(t_i, t_i+t_size),\n",
"                \"x\": slice(x_i, x_i+x_size),\n",
"                \"y\": slice(y_i, y_i+y_size),\n",
"            }\n",
"        ).load()\n",
"\n",
"    return\n",
"\n",
"\n",
"def get_chunk_sizes(ds):\n",
"    \"\"\"Get the chunk sizes of a dataset in a readable format\"\"\"\n",
"    chunk_sizes = {}\n",
"    for c in ds.chunks:\n",
"        if len(np.unique(ds.chunks[c]))==1:\n",
"            chunk_sizes[c] = ds.chunks[c][0]\n",
"        else:\n",
"            chunk_sizes[c] = ds.chunks[c]\n",
"    return chunk_sizes\n",
"\n",
"\n",
"def test_chunking_method_load_speed(zarr_paths, chunks, num_samples=100, t_size=6, x_size=10, y_size=10, seed=0):\n",
"    \"\"\"Test the speed of loading a number of samples from an xarray dataset opened with different chunking\n",
"\n",
"    Args:\n",
"        zarr_paths: A list of zarr paths to load\n",
"        chunks: The chunking method to pass to `xr.open_mfdataset()`\n",
"        num_samples: The number of samples to load\n",
"        t_size: The size of the sample window in the time dimension\n",
"        x_size: The size of the sample window in the x dimension\n",
"        y_size: The size of the sample window in the y dimension\n",
"        seed: The random seed\n",
"    \"\"\"\n",
"\n",
"    ds = xr.open_mfdataset(\n",
"        zarr_paths,\n",
"        engine=\"zarr\",\n",
"        concat_dim=\"time\",\n",
"        combine=\"nested\",\n",
"        chunks=chunks,\n",
"    )\n",
"\n",
"    # Print the size of the chunks in the dataset\n",
"    print(f\"{ds.data.shape=}\")\n",
"    print(\"Chunk sizes: \", get_chunk_sizes(ds))\n",
"\n",
"    # Measure the time taken to load samples from the data\n",
"    t0 = time.time()\n",
"    load_samples(ds, num_samples=num_samples, t_size=t_size, x_size=x_size, y_size=y_size, seed=seed)\n",
"    print(f\"Sample load time: {time.time()-t0:.2f} secs\")\n",
"\n",
"    return ds\n"
]
},
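{
"cell_type": "markdown",
"id": "d3f5e9cc-1f40-4b3d-8c8e-7a2f6e5d4c32",
"metadata": {},
"source": [
"A note on the two `chunks` values compared below: `chunks={}` asks xarray to use the engine's preferred chunking, i.e. one dask chunk per on-disk zarr chunk, while `chunks=\"auto\"` lets dask merge on-disk chunks into larger ones up to its configured target size. The sketch below (not executed here) reads that target from the dask config."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4a6fadd-2a51-4c4e-9d9f-8b3a7f6e5d43",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (not executed here): dask's \"auto\" chunking targets roughly this many bytes per chunk\n",
"import dask\n",
"\n",
"print(dask.config.get(\"array.chunk-size\"))  # commonly \"128MiB\" by default"
]
},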
{
"cell_type": "code",
"execution_count": 4,
"id": "c061ce04-919e-4710-af92-b0a8f778ccaa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"chunks='auto'\n",
"ds.data.shape=(2000, 150, 150)\n",
"Chunk sizes: {'time': (740, 260, 740, 260), 'x': 150, 'y': 150}\n",
"Sample load time: 2.47 secs\n",
"\n",
"\n",
"chunks={}\n",
"ds.data.shape=(2000, 150, 150)\n",
"Chunk sizes: {'time': 10, 'x': 30, 'y': 30}\n",
"Sample load time: 7.86 secs\n",
"\n",
"\n"
]
}
],
"source": [
"for chunks in [\"auto\", {}]:\n",
"    print(f\"{chunks=}\")\n",
"\n",
"    ds = test_chunking_method_load_speed(\n",
"        zarr_paths,\n",
"        chunks=chunks,\n",
"        num_samples=1000,\n",
"        t_size=5,\n",
"        x_size=10,\n",
"        y_size=10,\n",
"        seed=0\n",
"    )\n",
"\n",
"    print(\"\\n\")"
]
},
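{
"cell_type": "markdown",
"id": "f5b7abee-3b62-4d5f-8eaa-9c4b8a7f6e54",
"metadata": {},
"source": [
"The timings above suggest `chunks=\"auto\"` is roughly 3x faster than `chunks={}` for this access pattern. A plausible reading: keeping the small on-disk chunks means each `.load()` builds a dask graph over many tiny tasks, while the merged \"auto\" chunks give far fewer tasks per sample. The next cell repeats the comparison with single-element samples."
]
},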
{
"cell_type": "code",
"execution_count": 5,
"id": "9ce2c22b-050c-4e04-aecd-70fe74b8287b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"chunks='auto'\n",
"ds.data.shape=(2000, 150, 150)\n",
"Chunk sizes: {'time': (740, 260, 740, 260), 'x': 150, 'y': 150}\n",
"Sample load time: 2.25 secs\n",
"\n",
"\n",
"chunks={}\n",
"ds.data.shape=(2000, 150, 150)\n",
"Chunk sizes: {'time': 10, 'x': 30, 'y': 30}\n",
"Sample load time: 6.94 secs\n",
"\n",
"\n"
]
}
],
"source": [
"for chunks in [\"auto\", {}]:\n",
"    print(f\"{chunks=}\")\n",
"\n",
"    ds = test_chunking_method_load_speed(\n",
"        zarr_paths,\n",
"        chunks=chunks,\n",
"        num_samples=1000,\n",
"        t_size=1,\n",
"        x_size=1,\n",
"        y_size=1,\n",
"        seed=0\n",
"    )\n",
"\n",
"    print(\"\\n\")"
]
},
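{
"cell_type": "markdown",
"id": "a6c8bcff-4c73-4e6a-9fbb-ad5c9b8a7f65",
"metadata": {},
"source": [
"`xr.open_mfdataset` also accepts an explicit chunks dict. The sketch below (not executed here, with chunk sizes that are just a hypothetical middle ground) would benchmark merging on-disk chunks along `time` only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d9cd00-5d84-4f7b-8acc-be6dac9b8a76",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (not executed here): hypothetical explicit chunking merging 10 on-disk chunks along time\n",
"ds = test_chunking_method_load_speed(\n",
"    zarr_paths,\n",
"    chunks={\"time\": 100, \"x\": 30, \"y\": 30},\n",
"    num_samples=1000,\n",
"    t_size=5,\n",
"    x_size=10,\n",
"    y_size=10,\n",
"    seed=0\n",
")"
]
},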
{
"cell_type": "code",
"execution_count": 6,
"id": "580edb55-b041-4ca9-9f89-e7c8db1516d2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Clean up the saved zarr stores (shell command, so Unix-only)\n",
"os.system(\"rm -r \" + zarr_file_pattern.format(\"*\"))"
]
},
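{
"cell_type": "markdown",
"id": "c8eade11-6e95-4a8c-8bdd-cf7ebdac9b87",
"metadata": {},
"source": [
"A portable alternative: `rm -r` needs a Unix shell. The sketch below (not executed here) does the same cleanup with only the standard library."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9fbef22-7fa6-4b9d-8cee-da8fcebdac98",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import shutil\n",
"\n",
"# Sketch (not executed here): remove each saved zarr store directory\n",
"for path in glob.glob(zarr_file_pattern.format(\"*\")):\n",
"    shutil.rmtree(path)"
]
},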
{
"cell_type": "code",
"execution_count": null,
"id": "9fb69d34-58a3-4922-88c8-e045456a6843",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "pvnet0",
"name": "pytorch-gpu.1-13.m103",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m103"
},
"kernelspec": {
"display_name": "pvnet0",
"language": "python",
"name": "pvnet0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}