Skip to content

Instantly share code, notes, and snippets.

@matt-long
Last active April 22, 2020 15:35
Show Gist options
  • Select an option

  • Save matt-long/9c1efa02ad08e5f5d29539b4cab54d3c to your computer and use it in GitHub Desktop.

Select an option

Save matt-long/9c1efa02ad08e5f5d29539b4cab54d3c to your computer and use it in GitHub Desktop.
Compute annual means using xarray
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Demonstrate computing annual means from monthly data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import xarray as xr\n",
"import numpy as np"
]
},
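{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define helper functions\n",
"\n",
"Sometimes the time coordinate is stamped at the end of the averaging interval. In the test dataset below, for example, the January mean is labeled 1 February and the final December mean is labeled 1 January of the following year. Grouping by `time.year` then assigns months to the wrong year, so first define a function that resets time to the midpoint of the time bounds."
]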
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def center_time(ds, time_bnds_varname=None):\n",
" \"\"\"reset time in a dataset to be the mean of the time_bounds variable \"\"\"\n",
"\n",
" if time_bnds_varname is None:\n",
" try:\n",
" time_bnds_varname = ds.time.bounds\n",
" except:\n",
" raise ValueError('could not determine time bounds variable name')\n",
" \n",
" time_bnds_dim2 = ds[time_bnds_varname].dims[1]\n",
" \n",
" time_attrs = ds.time.attrs\n",
" ds['time'] = ds[time_bnds_varname].mean(time_bnds_dim2)\n",
" ds.time.attrs = time_attrs\n",
"\n",
" return ds"
]
},
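{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a function to compute annual means. Each time step is weighted by the length of its averaging interval, computed as the difference between its two `time_bound` values and normalized so that the weights within each year sum to 1."
]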
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def calc_ann_mean(ds, time_bnds_varname='time_bound'):\n",
"\n",
" group_by_year = 'time.year'\n",
" rename = {'year': 'time'}\n",
"\n",
" # compute time weights as time-bound diff\n",
" time_wgt = ds[time_bnds_varname].diff(dim=ds[time_bnds_varname].dims[1])\n",
" time_wgt_grouped = time_wgt.groupby(group_by_year)\n",
" time_wgt = time_wgt_grouped / time_wgt_grouped.sum(dim=xr.ALL_DIMS)\n",
"\n",
" nyr = len(time_wgt_grouped.groups)\n",
" time_wgt = time_wgt.squeeze()\n",
"\n",
" # ensure that weights sum to 1\n",
" np.testing.assert_almost_equal(time_wgt.groupby(group_by_year).sum(dim=xr.ALL_DIMS), \n",
" np.ones(nyr))\n",
"\n",
" # set non-time related vars to coords to avoid xarray adding a time-dim\n",
" nontime_vars = set([v for v in ds.variables if 'time' not in ds[v].dims]) - set(ds.coords)\n",
" dsop = ds.set_coords(nontime_vars).drop(time_bnds_varname)\n",
"\n",
" # compute the annual means\n",
" ds_ann = (dsop * time_wgt).groupby(group_by_year).sum(dim='time')\n",
"\n",
" # copy attrs \n",
" for v in ds_ann:\n",
" ds_ann[v].attrs = ds[v].attrs\n",
"\n",
" # rename time and put back the coords variable\n",
" ds_ann = ds_ann.reset_coords(nontime_vars).rename(rename)\n",
"\n",
" return ds_ann"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make a dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>&lt;xarray.Dataset&gt;\n",
"Dimensions: (d2: 2, lat: 2, lon: 2, time: 24)\n",
"Coordinates:\n",
" * lat (lat) int64 0 1\n",
" * lon (lon) int64 0 1\n",
" * time (time) object 0001-02-01 00:00:00 ... 0003-01-01 00:00:00\n",
" * d2 (d2) int64 0 1\n",
"Data variables:\n",
" time_bound (time, d2) object ...\n",
" variable_1 (time, lat, lon) float32 ...\n",
" variable_2 (time, lat, lon) float32 ...\n",
" non_time_variable_1 (lat, lon) float64 ...</pre>"
],
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (d2: 2, lat: 2, lon: 2, time: 24)\n",
"Coordinates:\n",
" * lat (lat) int64 0 1\n",
" * lon (lon) int64 0 1\n",
" * time (time) object 0001-02-01 00:00:00 ... 0003-01-01 00:00:00\n",
" * d2 (d2) int64 0 1\n",
"Data variables:\n",
" time_bound (time, d2) object ...\n",
" variable_1 (time, lat, lon) float32 ...\n",
" variable_2 (time, lat, lon) float32 ...\n",
" non_time_variable_1 (lat, lon) float64 ..."
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def dset():\n",
" \"\"\"Generate a simple test dataset\"\"\"\n",
" \n",
" start_date = np.array([0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334], dtype=np.float64)\n",
" start_date = np.append(start_date, start_date + 365)\n",
" end_date = np.array([31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365], dtype=np.float64)\n",
" end_date = np.append(end_date, end_date + 365)\n",
"\n",
" ds = xr.Dataset(coords={'time': 24, 'lat': 2, 'lon': 2, 'd2': 2})\n",
" ds['time'] = xr.DataArray(end_date, dims='time')\n",
" ds['lat'] = xr.DataArray([0, 1], dims='lat')\n",
" ds['lon'] = xr.DataArray([0, 1], dims='lon')\n",
" ds['d2'] = xr.DataArray([0, 1], dims='d2')\n",
" ds['time_bound'] = xr.DataArray(\n",
" np.array([start_date, end_date]).transpose(), dims=['time', 'd2']\n",
" )\n",
" ds['variable_1'] = xr.DataArray(\n",
" np.append(\n",
" np.zeros([12, 2, 2], dtype='float32'), np.ones([12, 2, 2], dtype='float32'), axis=0\n",
" ),\n",
" dims=['time', 'lat', 'lon'],\n",
" )\n",
" ds.variable_1.attrs['description'] = 'All zeroes for year 1, all ones for year 2'\n",
" \n",
" ds['variable_2'] = xr.DataArray(\n",
" np.append(\n",
" np.ones([12, 2, 2], dtype='float32'), np.zeros([12, 2, 2], dtype='float32'), axis=0\n",
" ),\n",
" dims=['time', 'lat', 'lon'],\n",
" )\n",
" ds.variable_2.attrs['description'] = 'All ones for year 1, all zeroes for year 2'\n",
" \n",
" ds['non_time_variable_1'] = xr.DataArray(np.ones((2, 2)), dims=['lat', 'lon'])\n",
" \n",
" ds.time.attrs['units'] = 'days since 0001-01-01 00:00:00'\n",
" ds.time.attrs['calendar'] = 'noleap'\n",
" ds.time.attrs['bounds'] = 'time_bound'\n",
"\n",
" return xr.decode_cf(ds.copy(True))\n",
"\n",
"ds = dset()\n",
"ds"
]
},
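{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute the annual means\n",
"\n",
"Center time on the bounds, then average. `variable_1` should be 0 in year 1 and 1 in year 2, and `variable_2` the reverse."
]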
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>&lt;xarray.Dataset&gt;\n",
"Dimensions: (d2: 2, lat: 2, lon: 2, time: 2)\n",
"Coordinates:\n",
" * lat (lat) int64 0 1\n",
" * lon (lon) int64 0 1\n",
" * d2 (d2) int64 0 1\n",
" * time (time) int64 1 2\n",
"Data variables:\n",
" non_time_variable_1 (lat, lon) float64 ...\n",
" variable_1 (time, lat, lon) float64 0.0 0.0 0.0 ... 1.0 1.0 1.0\n",
" variable_2 (time, lat, lon) float64 1.0 1.0 1.0 ... 0.0 0.0 0.0</pre>"
],
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (d2: 2, lat: 2, lon: 2, time: 2)\n",
"Coordinates:\n",
" * lat (lat) int64 0 1\n",
" * lon (lon) int64 0 1\n",
" * d2 (d2) int64 0 1\n",
" * time (time) int64 1 2\n",
"Data variables:\n",
" non_time_variable_1 (lat, lon) float64 ...\n",
" variable_1 (time, lat, lon) float64 0.0 0.0 0.0 ... 1.0 1.0 1.0\n",
" variable_2 (time, lat, lon) float64 1.0 1.0 1.0 ... 0.0 0.0 0.0"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds_ann = calc_ann_mean(center_time(ds))\n",
"ds_ann"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:miniconda3-co2-hole]",
"language": "python",
"name": "conda-env-miniconda3-co2-hole-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment