{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "670eee2e-84d1-4a74-a0ae-b442bba929a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dask.array as da\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "import time\n",
    "import zarr\n",
    "import os"
   ]
  },
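  {
   "cell_type": "markdown",
   "id": "intro-note",
   "metadata": {},
   "source": [
    "This notebook benchmarks how the `chunks` argument of `xr.open_mfdataset` affects the speed of repeatedly loading small random windows from zarr stores: sample data is written out first, then the same sampling workload is timed under `chunks=\"auto\"` and `chunks={}`."
   ]
  },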
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "088cce47-956a-43e7-9bc0-b326ad5e828b",
   "metadata": {},
   "outputs": [],
   "source": [
| "# Start by creating some sample data we can load from\n", | |
| "\n", | |
| "# We'll save this many zarr files\n", | |
| "num_zarr_files = 2\n", | |
| "\n", | |
| "# They will be asigned file names likle this\n", | |
| "zarr_file_pattern = \"data_partition_{}.zarr\"\n", | |
| "\n", | |
| "# Choose a compressor - I've only tested this and None\n", | |
| "compressor = zarr.Blosc(cname=\"zstd\", clevel=3, shuffle=2)\n", | |
| "\n", | |
| "# We'll make an array with these dimensions\n", | |
| "dims = [\"time\", \"x\", \"y\"]\n", | |
| "\n", | |
| "# We'll use these chunk sizes\n", | |
| "chunk_sizes = {\n", | |
| " \"time\": 10,\n", | |
| " \"x\": 30,\n", | |
| " \"y\": 30,\n", | |
| "}\n", | |
| "\n", | |
| "# And these number of chunks along each dimension\n", | |
| "number_of_chunks = {\n", | |
| " \"time\": 100,\n", | |
| " \"x\": 5,\n", | |
| " \"y\": 5,\n", | |
| "}\n", | |
| "\n", | |
| "\n", | |
| "# Create an array of random data. We'll save the same underlying data in each zarr but with different coords\n", | |
| "# This array is about 170MB\n", | |
| "dask_array = da.random.random(\n", | |
| " size=tuple(chunk_sizes[dim]*number_of_chunks[dim] for dim in dims), \n", | |
| " chunks=tuple(chunk_sizes[dim] for dim in dims)\n", | |
| ")\n", | |
| "\n", | |
| "# Convert to xarray Dataset\n", | |
| "coords = {dim: np.arange(n) for n, dim in zip(dask_array.shape, dims)}\n", | |
| "\n", | |
| "ds = xr.DataArray(\n", | |
| " dask_array, \n", | |
| " dims=dims, \n", | |
| " coords=coords,\n", | |
| ").to_dataset(name=\"data\")\n", | |
| "\n", | |
| "\n", | |
| "# Loop through and save copies of the array\n", | |
| "zarr_paths = []\n", | |
| "\n", | |
| "for i in range(num_zarr_files):\n", | |
| " \n", | |
| " # Shift the time coords to we can concat all the zarrs later\n", | |
| " ds_partition = ds.copy(deep=True)\n", | |
| " ds_partition[\"time\"] = ds[\"time\"] + len(coords[\"time\"])\n", | |
| " \n", | |
| " save_path = zarr_file_pattern.format(i)\n", | |
| " \n", | |
| " ds_partition.to_zarr(save_path, encoding={\"data\": {\"compressor\": compressor}})\n", | |
| " \n", | |
| " zarr_paths.append(save_path)" | |
| ] | |
| }, | |
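  {
   "cell_type": "code",
   "execution_count": null,
   "id": "chunk-layout-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (a sketch, not part of the original benchmark): open one\n",
    "# of the stores directly with zarr and confirm the chunk layout written to disk.\n",
    "# `zarr.open` returns the root group of the store; \"data\" is the variable saved above.\n",
    "stored = zarr.open(zarr_paths[0])[\"data\"]\n",
    "print(stored.shape, stored.chunks, stored.compressor)"
   ]
  },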
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "840489b0-a555-452e-a30e-04ac316a01ae",
   "metadata": {},
   "outputs": [],
   "source": [
| "\n", | |
| "def load_samples(ds, num_samples=100, t_size=6, x_size=10, y_size=10, seed=0):\n", | |
| " \"\"\"Load a small slice of the data from the input dataset mutliple times\n", | |
| " \n", | |
| " Args:\n", | |
| " ds: The xarray dataset\n", | |
| " num_samples: The number of samples to load\n", | |
| " t_size: The size of the sample window in the time dimension\n", | |
| " x_size: The size of the sample window in the x dimension\n", | |
| " y_size: The size of the sample window in the y dimension\n", | |
| " seed: The random seed\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " # Control for the random seed\n", | |
| " np.random.seed(seed=seed)\n", | |
| "\n", | |
| " # Sample many times\n", | |
| " for i in range(num_samples):\n", | |
| " \n", | |
| " # Randomly select the starting index for this sample across all dimensions\n", | |
| " t_i = np.random.randint(0, len(ds[\"time\"])-t_size)\n", | |
| " x_i = np.random.randint(0, len(ds[\"x\"])-x_size)\n", | |
| " y_i = np.random.randint(0, len(ds[\"y\"])-y_size)\n", | |
| " \n", | |
| " # Load the sample\n", | |
| " ds_sel = ds.isel(\n", | |
| " {\n", | |
| " \"time\": slice(t_i, t_i+t_size),\n", | |
| " \"x\": slice(x_i, x_i+x_size),\n", | |
| " \"y\": slice(y_i, y_i+y_size),\n", | |
| " }\n", | |
| " ).load()\n", | |
| " \n", | |
| " return\n", | |
| "\n", | |
| "\n", | |
| "def get_chunk_sizes(ds):\n", | |
| " \"\"\"Get the chunk sizes of dataset in a readible format\"\"\"\n", | |
| " chunk_sizes = {}\n", | |
| " for c in list(ds.coords):\n", | |
| " if len(np.unique(ds.chunks[c]))==1:\n", | |
| " chunk_sizes[c] = ds.chunks[c][0]\n", | |
| " else:\n", | |
| " chunk_sizes[c] = ds.chunks[c]\n", | |
| " return chunk_sizes\n", | |
| "\n", | |
| "\n", | |
| "def test_chunking_method_load_speed(zarr_paths, chunks, num_samples=100, t_size=6, x_size=10, y_size=10, seed=0):\n", | |
| " \"\"\"Test the speed of loading a number of samples from an xarray dataset loaded with different chunking\n", | |
| " \n", | |
| " Args:\n", | |
| " zarr_paths: A list of zarr paths to load\n", | |
| " chunks: The chunking method to pass to `xr.open_mfdataset()`\n", | |
| " num_samples: The number of samples to load\n", | |
| " t_size: The size of the sample window in the time dimension\n", | |
| " x_size: The size of the sample window in the x dimension\n", | |
| " y_size: The size of the sample window in the y dimension\n", | |
| " seed: The random seed\n", | |
| " \"\"\"\n", | |
| " \n", | |
| " ds = xr.open_mfdataset(\n", | |
| " zarr_paths,\n", | |
| " engine=\"zarr\",\n", | |
| " concat_dim=\"time\",\n", | |
| " combine=\"nested\",\n", | |
| " chunks=chunks,\n", | |
| " )\n", | |
| " \n", | |
| " # Print the size of the chunks in the dataset\n", | |
| " print(f\"{ds.data.shape=}\")\n", | |
| " print(\"Chunk sizes: \", get_chunk_sizes(ds))\n", | |
| " \n", | |
| " # Measure the time taken to load samples from the data\n", | |
| " t0 = time.time() \n", | |
| " load_samples(ds, num_samples=num_samples, t_size=t_size, x_size=x_size, y_size=y_size, seed=seed)\n", | |
| " print(f\"Sample load time: {time.time()-t0:.2f} secs\")\n", | |
| " \n", | |
| " return ds\n" | |
| ] | |
| }, | |
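  {
   "cell_type": "code",
   "execution_count": null,
   "id": "explicit-chunks-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A third option, sketched but not benchmarked here: `chunks` also accepts an\n",
    "# explicit dict, which rechunks the dask graph to the sizes given. The sizes\n",
    "# below are illustrative assumptions, not values tested in this notebook.\n",
    "# ds = test_chunking_method_load_speed(zarr_paths, chunks={\"time\": 50, \"x\": 150, \"y\": 150})"
   ]
  },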
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c061ce04-919e-4710-af92-b0a8f778ccaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chunks='auto'\n",
      "ds.data.shape=(2000, 150, 150)\n",
      "Chunk sizes:  {'time': (740, 260, 740, 260), 'x': 150, 'y': 150}\n",
      "Sample load time: 2.47 secs\n",
      "\n",
      "\n",
      "chunks={}\n",
      "ds.data.shape=(2000, 150, 150)\n",
      "Chunk sizes:  {'time': 10, 'x': 30, 'y': 30}\n",
      "Sample load time: 7.86 secs\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for chunks in [\"auto\", {}]:\n",
    "    print(f\"{chunks=}\")\n",
    "    \n",
    "    ds = test_chunking_method_load_speed(\n",
    "        zarr_paths, \n",
    "        chunks=chunks, \n",
    "        num_samples=1000,\n",
    "        t_size=5, \n",
    "        x_size=10, \n",
    "        y_size=10,\n",
    "        seed=0\n",
    "    )\n",
    "    \n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9ce2c22b-050c-4e04-aecd-70fe74b8287b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chunks='auto'\n",
      "ds.data.shape=(2000, 150, 150)\n",
      "Chunk sizes:  {'time': (740, 260, 740, 260), 'x': 150, 'y': 150}\n",
      "Sample load time: 2.25 secs\n",
      "\n",
      "\n",
      "chunks={}\n",
      "ds.data.shape=(2000, 150, 150)\n",
      "Chunk sizes:  {'time': 10, 'x': 30, 'y': 30}\n",
      "Sample load time: 6.94 secs\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for chunks in [\"auto\", {}]:\n",
    "    print(f\"{chunks=}\")\n",
    "    \n",
    "    ds = test_chunking_method_load_speed(\n",
    "        zarr_paths, \n",
    "        chunks=chunks, \n",
    "        num_samples=1000,\n",
    "        t_size=1, \n",
    "        x_size=1, \n",
    "        y_size=1,\n",
    "        seed=0\n",
    "    )\n",
    "    \n",
    "    print(\"\\n\")"
   ]
  },
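  {
   "cell_type": "markdown",
   "id": "results-interpretation",
   "metadata": {},
   "source": [
    "A plausible reading of the timings above (an inference, not something measured directly here): with `chunks={}` the dask graph mirrors the many small on-disk chunks, so per-sample task-graph overhead dominates, whereas `chunks=\"auto\"` consolidates them into a few large dask chunks and dask's slice fusion means each sample still only reads the zarr chunks it overlaps."
   ]
  },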
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "580edb55-b041-4ca9-9f89-e7c8db1516d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Clean up\n",
    "os.system(\"rm -r \"+zarr_file_pattern.format(\"*\"))"
   ]
  },
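  {
   "cell_type": "code",
   "execution_count": null,
   "id": "portable-cleanup-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A portable alternative to shelling out above (a sketch; does the same cleanup\n",
    "# without relying on a Unix shell for glob expansion):\n",
    "# import shutil\n",
    "# for path in zarr_paths:\n",
    "#     shutil.rmtree(path)"
   ]
  },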
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fb69d34-58a3-4922-88c8-e045456a6843",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "pvnet0",
   "name": "pytorch-gpu.1-13.m103",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m103"
  },
  "kernelspec": {
   "display_name": "pvnet0",
   "language": "python",
   "name": "pvnet0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}