Skip to content

Instantly share code, notes, and snippets.

@maxisoft
Created September 3, 2025 11:57
Show Gist options
  • Select an option

  • Save maxisoft/ee7895c37142228fb05d151fcbfa4719 to your computer and use it in GitHub Desktop.

Select an option

Save maxisoft/ee7895c37142228fb05d151fcbfa4719 to your computer and use it in GitHub Desktop.
A highly optimized, fully vectorized Array-to-Array zarr v3 codec that applies a temporal transform (XOR, delta, or double delta) along a chosen axis.
import zarr
import numpy as np
from dataclasses import dataclass, replace
from typing import Literal
from zarr.abc.codec import ArrayArrayCodec
from zarr.core.array_spec import ArraySpec
from zarr.core.buffer import NDBuffer
from zarr.core.common import parse_named_configuration
# Codec identifier. It is emitted as the "name" field by to_dict() and passed
# as the expected name to parse_named_configuration() in from_dict().
CODEC_NAME = "cryptodd.vectorized_temporal"
def _get_uint_equivalent(dtype: np.dtype) -> np.dtype:
"""
Gets the unsigned integer equivalent for a float dtype for bitwise operations.
"""
if dtype.kind not in ('f', 'i', 'u'):
raise TypeError(f"XOR transform is intended for float data, not {dtype}.")
# itemsize is in bytes, so multiply by 8 for bits
bits = dtype.itemsize * 8
if bits not in (16, 32, 64):
raise TypeError(f"Unsupported float bit size for XOR transform: float{bits}")
return np.dtype(f'uint{bits}')
# ---- Vectorized temporal codec (XOR / delta / double-delta) ----
@dataclass(frozen=True)
class VectorizedTemporalCodec(ArrayArrayCodec):
    """
    Vectorized zarr Array-to-Array codec applying a temporal transform along one axis.

    Along ``axis``, every element except the first is replaced by a function of
    itself and its predecessor; the first element is stored unchanged:

    * ``"xor"``: bitwise XOR with the previous element, computed on an
      unsigned-integer view of the data. Decoded with a cumulative XOR;
      lossless for every supported dtype.
    * ``"delta"``: arithmetic difference ``x[i] - x[i-1]``. Decoded with a
      cumulative sum.
    * ``"double_delta"``: the ``"delta"`` transform applied twice. Decoded with
      two cumulative sums.

    The codec is shape- and dtype-preserving (a pure filter). All work is done
    with NumPy ufuncs: no Python-level loops and no transpositions.

    Notes
    -----
    * Encoding is done out of place (one output array, two for
      ``"double_delta"``). The incoming chunk buffer is never modified, because
      it may alias caller-owned or read-only memory.
    * For integer dtypes, ``"delta"``/``"double_delta"`` wrap on overflow and the
      cumulative sum wraps back, so they round-trip exactly.
    * For floating-point dtypes, rounding in the subtraction and accumulation
      means decoded values may differ from the originals in the last bits.
      Non-finite values (inf, NaN) can also corrupt the values that follow
      them. Use ``"xor"`` when bit-exact float storage is required.
    """

    is_fixed_size = True

    # Which temporal transform to apply; validated in ``__post_init__``.
    transform: Literal["xor", "delta", "double_delta"]
    # Axis along which neighbouring elements are combined. Negative values count
    # from the last axis. Defaults to 0 (presumably the time axis).
    axis: int = 0

    # Accepted values for ``transform`` (unannotated, so not a dataclass field).
    _TRANSFORMS = ("xor", "delta", "double_delta")

    def __post_init__(self) -> None:
        # Fail at construction time. Otherwise an invalid configuration would
        # only surface (or, when encoding, be silently ignored) per chunk.
        if self.transform not in self._TRANSFORMS:
            raise ValueError(
                f"Unknown transform {self.transform!r}; expected one of {self._TRANSFORMS}."
            )

    def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
        """Return ``input_byte_length`` unchanged: the transform preserves chunk size."""
        return input_byte_length

    def to_dict(self):
        """Serialize the codec as ``{"name": CODEC_NAME, "configuration": {...}}``."""
        return {
            "name": CODEC_NAME,
            "configuration": {
                "transform": self.transform,
                "axis": self.axis,
            },
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a codec from the metadata produced by :meth:`to_dict`."""
        _, configuration_parsed = parse_named_configuration(data, CODEC_NAME)
        return cls(**configuration_parsed)  # type: ignore[arg-type]

    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
        """Return ``chunk_spec`` unchanged: the chunk shape and dtype are preserved."""
        return chunk_spec

    # --- Helpers ---

    @staticmethod
    def _axis_selections(ndim: int, axis: int) -> tuple[tuple[slice, ...], tuple[slice, ...], tuple[slice, ...]]:
        """
        Build index tuples selecting ``x[0:1]``, ``x[1:]`` and ``x[:-1]`` along ``axis``.

        All other axes are selected whole. A negative ``axis`` counts from the
        end, as in NumPy.

        Raises
        ------
        IndexError
            If ``axis`` is out of bounds for ``ndim`` dimensions.
        """
        if not -ndim <= axis < ndim:
            raise IndexError(f"axis {axis} is out of bounds for a chunk with {ndim} dimension(s).")

        def along(selection: slice) -> tuple[slice, ...]:
            index = [slice(None)] * ndim
            index[axis] = selection  # list assignment handles negative axes
            return tuple(index)

        return along(slice(0, 1)), along(slice(1, None)), along(slice(None, -1))

    @staticmethod
    def _neighbor_op(values: np.ndarray, ufunc: np.ufunc, first, later, earlier) -> np.ndarray:
        """
        Return a new array ``out`` with ``out[first] = values[first]`` and
        ``out[later] = ufunc(values[later], values[earlier])``.

        ``values`` itself is left untouched. Writing into a fresh array also
        avoids any input/output overlap.
        """
        out = np.empty_like(values)
        out[first] = values[first]
        ufunc(values[later], values[earlier], out=out[later])
        return out

    # --- Encoding (Original -> Transformed) ---
    async def _encode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer | None:
        """Apply the configured transform along ``self.axis``; the input buffer is not modified."""
        data = chunk_array.as_numpy_array()
        first, later, earlier = self._axis_selections(data.ndim, self.axis)

        if self.transform == "xor":
            uint_dtype = _get_uint_equivalent(chunk_spec.dtype.to_native_dtype())
            # XOR the raw bits, then reinterpret the result with the input's own
            # dtype. XOR acts byte-wise, so the byte order of the view is irrelevant.
            xored = self._neighbor_op(data.view(uint_dtype), np.bitwise_xor, first, later, earlier)
            encoded = xored.view(data.dtype)
        elif self.transform == "delta":
            encoded = self._neighbor_op(data, np.subtract, first, later, earlier)
        elif self.transform == "double_delta":
            # Second-order difference = first-order difference applied twice.
            first_delta = self._neighbor_op(data, np.subtract, first, later, earlier)
            encoded = self._neighbor_op(first_delta, np.subtract, first, later, earlier)
        else:
            raise RuntimeError("Unknown transform type")
        return chunk_array.from_numpy_array(encoded)

    # --- Decoding (Transformed -> Original) ---
    async def _decode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer:
        """Invert the transform with cumulative ufunc reductions along ``self.axis``."""
        encoded = chunk_array.as_numpy_array()
        native_dtype = chunk_spec.dtype.to_native_dtype()

        if self.transform == "xor":
            uint_dtype = _get_uint_equivalent(native_dtype)
            # Cumulative XOR undoes the neighbour XOR. View the bits back with the
            # incoming array's dtype so that its byte order is kept intact.
            decoded_bits = np.bitwise_xor.accumulate(encoded.view(uint_dtype), axis=self.axis)
            decoded = decoded_bits.view(encoded.dtype)
        elif self.transform == "delta":
            # Explicit dtype keeps integer wrap-around consistent with encoding.
            decoded = np.add.accumulate(encoded, axis=self.axis, dtype=native_dtype)
        elif self.transform == "double_delta":
            first_delta = np.add.accumulate(encoded, axis=self.axis, dtype=native_dtype)
            decoded = np.add.accumulate(first_delta, axis=self.axis, dtype=native_dtype)
        else:
            raise RuntimeError("Unknown transform type")
        return chunk_array.from_numpy_array(decoded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment