Created
November 17, 2025 19:24
-
-
Save grovduck/751179d933f6ed3523fcd861ded890b7 to your computer and use it in GitHub Desktop.
Demonstrate using `sklearn_raster.features.FeatureArray` with a custom ufunc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script 1: fit a k-nearest-neighbors estimator on the SWO ecoplot example data
# and write the per-pixel neighbor IDs and neighbor distances to GeoTIFFs.
import dask
from dask.distributed import Client
from sknnr import EuclideanKNNRegressor
from sklearn_raster import FeatureArrayEstimator
from sklearn_raster.datasets import load_swo_ecoplot

if __name__ == "__main__":
    # Load example dataset
    X_img, X, y = load_swo_ecoplot(as_dataset=True, large_rasters=True)

    # Create and fit the estimator
    est = FeatureArrayEstimator(EuclideanKNNRegressor(n_neighbors=5)).fit(X, y)

    # Create the neighbors and distances
    with Client():
        distances, neighbors = est.kneighbors(
            X_img, return_distance=True, return_dataframe_index=True
        )
        # Materialize both results in a single pass over the task graph.
        distances, neighbors = dask.compute(distances, neighbors)

        # Both results are already computed at this point, so writing them
        # only performs file I/O.
        distances.rio.to_raster("swo_ecoplot_distances.tif")
        neighbors.rio.to_raster("swo_ecoplot_neighbors.tif")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import wraps

import numpy as np
import rioxarray as rxr
import sknnr.datasets
import xarray as xr
from dask.distributed import Client
from sklearn_raster.features import FeatureArray
def _stack_rows(outputs):
    """Stack per-row outputs into one 2D array with a row per input row."""
    # Each output is flattened to a single row, so a scalar output yields an
    # (n_rows, 1) array and a length-m output yields an (n_rows, m) array.
    return np.vstack([np.asarray(out).reshape(1, -1) for out in outputs])


def vectorize_first_arg(func):
    """Decorate ``func`` so it is applied row by row over its first argument.

    The wrapped function receives the first argument converted with
    ``np.asarray``:

    * A 1D input is passed straight to ``func`` and its result is returned
      unchanged.
    * Any other input is iterated along its first axis and ``func`` is called
      once per row. If the first call returns a tuple or list, the wrapper
      returns a tuple with one stacked 2D array per output position;
      otherwise it returns a single stacked 2D array.

    All remaining positional and keyword arguments are forwarded to every
    call of ``func`` unchanged.

    Note that a multi-row input must contain at least one row: the output
    layout is inferred from the first result, so an empty input raises
    ``IndexError``.
    """

    @wraps(func)
    def wrapper(arr, *args, **kwargs):
        arr = np.asarray(arr)

        # If 1D input, call directly
        if arr.ndim == 1:
            return func(arr, *args, **kwargs)

        results = [func(row, *args, **kwargs) for row in arr]

        # If function returns multiple outputs (tuple/list), stack each
        # output position separately. The first result decides how many
        # outputs there are.
        first = results[0]
        if isinstance(first, (tuple, list)):
            return tuple(
                _stack_rows([result[i] for result in results])
                for i in range(len(first))
            )

        # Single-output function
        return _stack_rows(results)

    return wrapper
@vectorize_first_arg
def weighted_attribute(vec, k, keys, values):
    """Compute weighted attribute value for a single pixel over k neighbors.

    Parameters
    ----------
    vec
        1D array for one pixel: the first ``k`` entries are neighbor IDs and
        the next ``k`` entries are the matching neighbor distances.
    k
        Number of neighbors stored in ``vec``.
    keys
        IDs used to look up attribute values. They are searched with
        ``np.searchsorted``, so they must be sorted in ascending order and
        every neighbor ID is expected to be present; this is not checked
        here.
    values
        Attribute values aligned with ``keys``.

    Returns
    -------
    The inverse-distance-weighted sum of the neighbors' attribute values.
    """
    nn_vec = vec[:k]
    dist_vec = vec[k : 2 * k]

    # Inverse-distance weights; the +1 keeps a zero distance finite.
    weight_vec = 1.0 / (1.0 + dist_vec)
    # Normalize so the weights sum to one.
    weight_vec = weight_vec / weight_vec.sum()

    # Map each neighbor ID to its position in ``keys``, then to its value.
    attr_vec = values[np.searchsorted(keys, nn_vec)]
    return (attr_vec * weight_vec).sum()
if __name__ == "__main__":
    # Number of neighbors per pixel.
    # NOTE(review): presumably must match the band count of each input raster
    # (the producing script uses n_neighbors=5) - confirm.
    k = 5
    chunks = {"x": 1024, "y": 1024}

    # Read in the neighbors and distances rasters and stack them into a
    # single DataArray. Neighbor bands come first and distance bands second,
    # which is the layout `weighted_attribute` slices apart.
    nn_da = rxr.open_rasterio("swo_ecoplot_neighbors.tif", chunks=chunks).squeeze()
    dist_da = rxr.open_rasterio("swo_ecoplot_distances.tif", chunks=chunks).squeeze()
    neighbors_da = xr.concat([nn_da, dist_da], dim="band")

    # Read in the SWO attribute table and store the index and PSME_COV
    _, y = sknnr.datasets.load_swo_ecoplot(return_X_y=True, as_frame=True)
    keys, values = y.index.values, y.PSME_COV.values

    # Convert the neighbors_da into a sklearn_raster FeatureArray
    features = FeatureArray.from_feature_array(neighbors_da)

    with Client():
        # Calculate the weighted attribute and save to raster
        weighted_attr_data = features.apply_ufunc_across_features(
            weighted_attribute,
            output_dims=[["band"]],
            output_dtypes=[np.float32],
            output_sizes={"band": 1},
            nodata_output=np.float32(-1.0),
            keep_attrs=False,
            k=k,
            keys=keys,
            values=values,
        )
        weighted_attr_data.rio.to_raster("swo_ecoplot_psme_cov.tif")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment