Source code for xsdba.base

"""
Base Classes and Developer Tools
================================
"""

from __future__ import annotations
import operator
from collections import UserDict
from collections.abc import Callable, Sequence
from inspect import _empty, signature

import cftime
import dask.array as dsk
import jsonpickle
import numpy as np
import pandas as pd
import xarray as xr
from boltons.funcutils import wraps
from xarray.core import dtypes


# TODO : Redistributes some functions in existing/new scripts


# ## Base class for the sdba module
[docs] class Parametrizable(UserDict): """ Helper base class resembling a dictionary. This object is _completely_ defined by the content of its internal dictionary, accessible through item access (`self['attr']`) or in `self.parameters`. When serializing and restoring this object, only members of that internal dict are preserved. All other attributes set directly with `self.attr = value` will not be preserved upon serialization and restoration of the object with `[json]pickle` dictionary. Other variables set with `self.var = data` will be lost in the serialization process. This class is best serialized and restored with `jsonpickle`. """ _repr_hide_params = [] def __getstate__(self): """For (json)pickle, a Parametrizable should be defined by its internal dict only.""" return self.data def __setstate__(self, state): """For (json)pickle, a Parametrizable in only defined by its internal dict.""" # Unpickling skips the init, so we must _set_ data, we can't just update it - it's not there yet self.data = {**state} def __getattr__(self, attr): """Get attributes.""" if attr == "data" or attr not in self.data: return self.__getattribute__(attr) return self.data[attr] @property def parameters(self) -> dict: """All parameters as a dictionary. Read-only.""" return {**self.data} def __repr__(self) -> str: """Return a string representation.""" # Get default values from the init signature defaults = { # A default value of None could mean an empty mutable object n: [p.default] if p.default is not None else [[], {}, set(), None] for n, p in signature(self.__init__).parameters.items() if p.default is not _empty } # The representation only includes the parameters with a value different from their default # and those not explicitly excluded. params = ", ".join([f"{k}={v!r}" for k, v in self.items() if k not in self._repr_hide_params and v not in defaults.get(k, [])]) return f"{self.__class__.__name__}({params})"
[docs] class ParametrizableWithDataset(Parametrizable): """Parametrizable class that also has a `ds` attribute storing a dataset.""" _attribute = "_xsdba_parameters"
[docs] @classmethod def from_dataset(cls, ds: xr.Dataset): """ Create an instance from a dataset. The dataset must have a global attribute with a name corresponding to `cls._attribute`, and that attribute must be the result of `jsonpickle.encode(object)` where object is of the same type as this object. """ obj = jsonpickle.decode(ds.attrs[cls._attribute]) # noqa: S301 obj.set_dataset(ds) return obj
[docs] def set_dataset(self, ds: xr.Dataset) -> None: """ Store an xarray dataset in the `ds` attribute. Useful with custom object initialization or if some external processing was performed. """ self.ds = ds self.ds.attrs[self._attribute] = jsonpickle.encode(self)
# XC: calendar # TODO: remove this and use `ds[self.dim].dt.days_in_year.max().item()` when minimum xarray is 2024.09 max_doy = { "standard": 366, "gregorian": 366, "proleptic_gregorian": 366, "julian": 366, "noleap": 365, "365_day": 365, "all_leap": 366, "366_day": 366, "360_day": 360, }
[docs] class Grouper(Parametrizable): """Grouper inherited class for parameterizable classes.""" _repr_hide_params = ["dim", "prop"] # For a concise repr # Two constants for use of `map_blocks` and `map_groups`. # They provide better code readability, nothing more PROP = "<PROP>" DIM = "<DIM>" ADD_DIMS = "<ADD_DIMS>" def __init__( self, group: str, window: int = 1, add_dims: Sequence[str] | set[str] | None = None, ): """ Create the Grouper object. Parameters ---------- group : str The usual grouping name as xarray understands it. Ex: "time.month" or "time". The dimension name before the dot is the "main dimension" stored in `Grouper.dim` and the property name after is stored in `Grouper.prop`. window : int If larger than 1, a centered rolling window along the main dimension is created when grouping data. Units are the sampling frequency of the data along the main dimension. add_dims : Optional[Union[Sequence[str], str]] Additional dimensions that should be reduced in grouping operations. This behaviour is also controlled by the `main_only` parameter of the `apply` method. If any of these dimensions are absent from the DataArrays, they will be omitted. """ if group == "time" and window > 1: raise ValueError( "The group given is 'time', but the window given is greater than 1. The `group = 'time'` option " "takes the complete series, thus the concept of window is not applicable in this case. When using `group = 'time'`, " "`window=1` is expected." ) if "." in group: dim, prop = group.split(".") else: dim, prop = group, "group" # TODO : Remove this special workaround # This will only work with MBCn if group == "5D": dim = "time" if isinstance(add_dims, str): add_dims = [add_dims] add_dims = add_dims or [] super().__init__( dim=dim, add_dims=add_dims, prop=prop, name=group, window=window, )
[docs] @classmethod def from_kwargs(cls, **kwargs) -> dict[str, Grouper]: """Parameterize groups using kwargs.""" kwargs["group"] = cls( group=kwargs.pop("group"), window=kwargs.pop("window", 1), add_dims=kwargs.pop("add_dims", []), ) return kwargs
@property def freq(self): """ Format a frequency string corresponding to the group. For use with xarray's resampling functions. """ return { "group": "YS", "season": "QS-DEC", "month": "MS", "week": "W", "dayofyear": "D", }.get(self.prop, None) @property def prop_name(self): """Create a significant name for the grouping.""" return "year" if self.prop == "group" else self.prop
[docs] def get_coordinate(self, ds: xr.Dataset | None = None) -> xr.DataArray: """ Return the coordinate as in the output of group.apply. Currently, only implemented for groupings with prop == `month` or `dayofyear`. For prop == `dayfofyear`, a ds (Dataset or DataArray) can be passed to infer the max day of year from the available years and calendar. """ if self.prop == "month": return xr.DataArray(np.arange(1, 13), dims=("month",), name="month") if self.prop == "season": return xr.DataArray(["DJF", "MAM", "JJA", "SON"], dims=("season",), name="season") if self.prop == "dayofyear": if ds is not None: cal = ds.time.dt.calendar # TODO : Change this to `ds[self.dim].dt.days_in_year.max().item()` when minimum xarray is 2024.09 mdoy = max_doy[cal] else: mdoy = 365 return xr.DataArray(np.arange(1, mdoy + 1), dims="dayofyear", name="dayofyear") if self.prop == "group": return xr.DataArray([1], dims=("group",), name="group") # TODO: woups what happens when there is no group? (prop is None) raise NotImplementedError("No grouping found.")
[docs] def group( self, da: xr.DataArray | xr.Dataset | None = None, main_only: bool = False, **das: xr.DataArray, ) -> xr.core.groupby.GroupBy: # pylint: disable=no-member """ Return a xr.core.groupby.GroupBy object. More than one array can be combined to a dataset before grouping using the `das` kwargs. A new `window` dimension is added if `self.window` is larger than 1. If `Grouper.dim` is 'time', but 'prop' is None, the whole array is grouped together. When multiple arrays are passed, some of them can be grouped along the same group as self. They are broadcast, merged to the grouping dataset and regrouped in the output. """ if das: from .utils import ( # pylint: disable=cyclic-import,import-outside-toplevel broadcast, ) if da is not None: das[da.name] = da da = xr.Dataset(data_vars={name: das.pop(name) for name in list(das.keys()) if self.dim in das[name].dims}) # "Ungroup" the grouped arrays da = da.assign({name: broadcast(var, da[self.dim], group=self, interp="nearest") for name, var in das.items()}) if not main_only and self.window > 1: da = da.rolling(center=True, **{self.dim: self.window}).construct(window_dim="window") if uses_dask(da): # Rechunk. There might be padding chunks. da = da.chunk({self.dim: -1}) if self.prop == "group": group = self.get_index(da) else: group = self.name return da.groupby(group)
[docs] def get_index( self, da: xr.DataArray | xr.Dataset, interp: bool | None = None, ) -> xr.DataArray: """ Return the group index of each element along the main dimension. Parameters ---------- da : xr.DataArray or xr.Dataset The input array/dataset for which the group index is returned. It must have `Grouper.dim` as a coordinate. interp : bool, optional If True, the returned index can be used for interpolation. Only value for month grouping, where integer values represent the middle of the month, all other days are linearly interpolated in between. Returns ------- xr.DataArray The index of each element along `Grouper.dim`. If `Grouper.dim` is `time` and `Grouper.prop` is None, a uniform array of True is returned. If `Grouper.prop` is a time accessor (month, dayofyear, etc.), a numerical array is returned, with a special case of `month` and `interp=True`. If `Grouper.dim` is not `time`, the dim is simply returned. """ if self.prop == "group": if self.dim == "time": return xr.full_like(da[self.dim], 1, dtype=int).rename("group") return da[self.dim].rename("group") ind = da.indexes[self.dim] if interp and self.dim == "time": if self.prop == "month": i = ind.month - 0.5 + ind.day / ind.days_in_month elif self.prop == "season": calendar = ind.calendar if hasattr(ind, "calendar") else "standard" length_year = 360 if calendar == "360_day" else 365 + (0 if calendar == "noleap" else ind.is_leap_year) # This is assuming that seasons have the same length. The factor 1/6 comes from the fact that # the first season is shifted by 1 month the but the middle of the season is shifted in the other direction # by half a month so -(1/12-1/24)*4 = -1/6 i = ind.dayofyear / length_year * 4 - 1 / 6 elif self.prop == "dayofyear": i = ind.dayofyear else: raise ValueError(f"Interpolation is not supported for {self.dim}.{self.prop}.") else: if self.prop == "week": i = da[self.dim].copy(data=ind.isocalendar().week).astype(int) elif self.prop == "season": i = da[self.dim].copy(data=ind.month % 12 // 3) else: i = getattr(ind, self.prop) if not np.issubdtype(i.dtype, np.integer): raise ValueError(f"Index {self.name} is not of type int (rather {i.dtype}), but {self.__class__.__name__} requires integer indexes.") xi = xr.DataArray( i, dims=self.dim, coords={self.dim: da.coords[self.dim]}, name=self.dim + " group index", ) # Expand dimensions of index to match the dimensions of da # We want vectorized indexing with no broadcasting # xi = xi.broadcast_like(da) xi.name = self.prop return xi
[docs] def apply( self, func: Callable | str, da: xr.DataArray | dict[str, xr.DataArray] | xr.Dataset, main_only: bool = False, **kwargs, ) -> xr.DataArray | xr.Dataset: r""" Apply a function group-wise on DataArrays. Parameters ---------- func : Callable or str The function to apply to the groups, either a callable or a `xr.core.groupby.GroupBy` method name as a string. The function will be called as `func(group, dim=dims, **kwargs)`. See `main_only` for the behaviour of `dims`. da : xr.DataArray or dict[str, xr.DataArray] or xr.Dataset The DataArray on which to apply the function. Multiple arrays can be passed through a dictionary. A dataset will be created before grouping. main_only : bool Whether to call the function with the main dimension only (if True) or with all grouping dims (if False, default) (including the window and dimensions given through `add_dims`). The dimensions used are also written in the "group_compute_dims" attribute. If all the input arrays are missing one of the 'add_dims', it is silently omitted. **kwargs Other keyword arguments to pass to the function. Returns ------- xr.DataArray or xr.Dataset Attributes "group", "group_window" and "group_compute_dims" are added. If the function did not reduce the array: - The output is sorted along the main dimension. - The output is rechunked to match the chunks on the input If multiple inputs with differing chunking were given as inputs, the chunking with the smallest number of chunks is used. If the function reduces the array: - If there is only one group, the singleton dimension is squeezed out of the output - The output is rechunked as to have only 1 chunk along the new dimension. Notes ----- For the special case where a Dataset is returned, but only some of its variable where reduced by the grouping, xarray's `GroupBy.map` will broadcast everything back to the ungrouped dimensions. To overcome this issue, function may add a "_group_apply_reshape" attribute set to `True` on the variables that should be reduced and these will be re-grouped by calling `da.groupby(self.name).first()`. """ if isinstance(da, dict | xr.Dataset): grpd = self.group(main_only=main_only, **da) dim_chunks = min( # Get smallest chunking to rechunk if the operation is non-grouping [d.chunks[d.get_axis_num(self.dim)] for d in da.values() if uses_dask(d) and self.dim in d.dims] or [[]], # pass [[]] if no DataArrays have chunks so min doesn't fail key=len, ) else: grpd = self.group(da, main_only=main_only) # Get chunking to rechunk is the operation is non-grouping # To match the behaviour of the case above, an empty list signifies that dask is not used for the input. dim_chunks = [] if not uses_dask(da) else da.chunks[da.get_axis_num(self.dim)] if main_only: dims = self.dim else: dims = [self.dim] + [d for d in self.add_dims if d in grpd.dims] if self.window > 1: dims += ["window"] if isinstance(func, str): out = getattr(grpd, func)(dim=dims, **kwargs) else: out = grpd.map(func, dim=dims, **kwargs) # Case where the function wants to return more than one variable. # and that some have grouped dims and other have the same dimensions as the input. # In that specific case, groupby broadcasts everything back to the input's dim, copying the grouped data. if isinstance(out, xr.Dataset): for name, out_var in out.data_vars.items(): if "_group_apply_reshape" in out_var.attrs: if self.dim in out_var.dims: out[name] = self.group(out_var, main_only=True).first(skipna=False, keep_attrs=True) del out[name].attrs["_group_apply_reshape"] # Save input parameters as attributes of output DataArray. out.attrs["group"] = self.name out.attrs["group_compute_dims"] = dims out.attrs["group_window"] = self.window # On non-reducing ops, drop the constructed window if self.window > 1 and "window" in out.dims: out = out.isel(window=self.window // 2, drop=True) # If the grouped operation did not reduce the array, the result is sometimes unsorted along dim if self.dim in out.dims: out = out.sortby(self.dim) # The expected behavior for downstream methods would be to conserve chunking along dim if uses_dask(out): # or -1 in case dim_chunks is [], when no input is chunked # (only happens if the operation is chunking the output) out = out.chunk({self.dim: dim_chunks or -1}) if self.prop == "season" and self.prop in out.coords: # Special case for "DIM.season", it is often returned in alphabetical order, # but that doesn't fit the coord given in get_coordinate out = out.sel(season=np.array(["DJF", "MAM", "JJA", "SON"])) if self.prop in out.dims and uses_dask(out): # Same as above : downstream methods expect only one chunk along the group out = out.chunk({self.prop: -1}) return out
[docs] @staticmethod def filter_dim(da: xr.DataArray, dim: str | list[str]): """ Filter the dimensions to be reduced by removing those not on the variable. The first dimension is never removed as it is considered the "main" dimension and not having it is an error. This is meant to be used within a function sent to :py:meth:`Grouper.apply`, like those decorated with :py:func:`map_groups`. Parameters ---------- da: DataArray A DataArray from which we get the list of valid dimensions. dim: str or sequence of str Dimension(s) to reduce. The first one is not removed, the others are kept only if they appear on `da`. Returns ------- list of str, the filtered dimensions list """ if isinstance(dim, str): return dim return [dim[0]] + [d for d in dim[1:] if d in da.dims]
[docs] @staticmethod def filter_add_dims(dim: list[str]): """ Filter the dimensions to be reduced by removing those in `add_dims`. The first dimension is never removed as it is considered the "main" dimension and not having it is an error. "window" is also kept if present. Parameters ---------- dim: str or sequence of str Dimension(s) to reduce. Dimensions that do not conform with Grouper.DIM (add_dims) are removed. Returns ------- list of str, the filtered dimensions list """ if isinstance(dim, str): return [dim] extra_dim = list(set(dim[1:]) - {"window"}) return list(set(dim) - set(extra_dim))
[docs] def parse_group(func: Callable, kwargs=None, allow_only=None) -> Callable: """ Parse the kwargs given to a function to set the `group` arg with a Grouper object. This function can be used as a decorator, in which case the parsing and updating of the kwargs is done at call time. It can also be called with a function from which extract the default group and kwargs to update, in which case it returns the updated kwargs. If `allow_only` is given, an exception is raised when the parsed group is not within that list. """ sig = signature(func) if "group" in sig.parameters: default_group = sig.parameters["group"].default else: default_group = None def _update_kwargs(_kwargs, allowed=None): if default_group or "group" in _kwargs: _kwargs.setdefault("group", default_group) if not isinstance(_kwargs["group"], Grouper): _kwargs = Grouper.from_kwargs(**_kwargs) if allowed is not None and "group" in _kwargs and _kwargs["group"].prop not in allowed: raise ValueError(f"Grouping on {_kwargs['group'].prop_name} is not allowed for this function. Should be one of {allowed}.") return _kwargs if kwargs is not None: # Not used as a decorator return _update_kwargs(kwargs, allowed=allow_only) # else (then it's a decorator) @wraps(func) def _parse_group(*f_args, **f_kwargs): f_kwargs = _update_kwargs(f_kwargs, allowed=allow_only) return func(*f_args, **f_kwargs) return _parse_group
[docs] def duck_empty(dims: xr.DataArray.dims, sizes, dtype="float64", chunks=None) -> xr.DataArray: """Return an empty DataArray based on a numpy or dask backend, depending on the "chunks" argument.""" shape = [sizes[dim] for dim in dims] if chunks: chnks = [chunks.get(dim, (sizes[dim],)) for dim in dims] content = dsk.empty(shape, chunks=chnks, dtype=dtype) else: content = np.empty(shape, dtype=dtype) return xr.DataArray(content, dims=dims)
def _decode_cf_coords(ds: xr.Dataset): """Decode coords in-place.""" crds = xr.decode_cf(ds.coords.to_dataset()) for crdname in list(ds.coords.keys()): ds[crdname] = crds[crdname] # decode_cf introduces an encoding key for the dtype, which can confuse the netCDF writer dtype = ds[crdname].encoding.get("dtype") if np.issubdtype(dtype, np.timedelta64) or np.issubdtype(dtype, np.datetime64): del ds[crdname].encoding["dtype"]
[docs] def map_blocks( # noqa: C901 reduces: Sequence[str] | None = None, **out_vars ) -> Callable: r""" Decorator for declaring functions and wrapping them into a map_blocks. Takes care of constructing the template dataset. Dimension order is not preserved. The decorated function must always have the signature: ``func(ds, **kwargs)``, where ds is a DataArray or a Dataset. It must always output a dataset matching the mapping passed to the decorator. Parameters ---------- reduces : sequence of strings Name of the dimensions that are removed by the function. **out_vars Mapping from variable names in the output to their *new* dimensions. The placeholders ``Grouper.PROP``, ``Grouper.DIM`` and ``Grouper.ADD_DIMS`` can be used to signify ``group.prop``,``group.dim`` and ``group.add_dims`` respectively. If an output keeps a dimension that another loses, that dimension name must be given in ``reduces`` and in the list of new dimensions of the first output. """ def merge_dimensions(*seqs): """Merge several dimensions lists while preserving order.""" out = seqs[0].copy() for seq in seqs[1:]: last_index = 0 for e in seq: if e in out: indx = out.index(e) if indx < last_index: raise ValueError("Dimensions order mismatch, lists are not mergeable.") last_index = indx else: out.insert(last_index + 1, e) return out # Ordered list of all added dimensions out_dims = merge_dimensions(*out_vars.values()) # List of dimensions reduced by the function. red_dims = reduces or [] def _decorator(func): # noqa: C901 # @wraps(func, hide_wrapped=True) @parse_group def _map_blocks(ds, **kwargs): # noqa: C901 if isinstance(ds, xr.Dataset): ds = ds.unify_chunks() # Get group if present group = kwargs.get("group") # Ensure group is given as it might not be in the signature of the wrapped func if {Grouper.PROP, Grouper.DIM, Grouper.ADD_DIMS}.intersection(out_dims + red_dims) and group is None: raise ValueError("Missing required `group` argument.") # Make translation dict if group is not None: placeholders = { Grouper.PROP: [group.prop], Grouper.DIM: [group.dim], Grouper.ADD_DIMS: group.add_dims, } else: placeholders = {} if group.add_dims is not None and set(group.add_dims).issubset(set(ds.dims)) is False: raise ValueError("`add_dims` argument needs to be a dimension in one of the input datasets.") # Get new dimensions (in order), translating placeholders to real names. new_dims = [] for dim in out_dims: new_dims.extend(placeholders.get(dim, [dim])) reduced_dims = [] for dim in red_dims: reduced_dims.extend(placeholders.get(dim, [dim])) if uses_dask(ds): # Use dask if any of the input is dask-backed. chunks = dict(ds.chunks) if isinstance(ds, xr.Dataset) else dict(zip(ds.dims, ds.chunks, strict=False)) badchunks = {} if group is not None: badchunks.update({dim: chunks.get(dim) for dim in group.add_dims + [group.dim] if len(chunks.get(dim, [])) > 1}) badchunks.update({dim: chunks.get(dim) for dim in reduced_dims if len(chunks.get(dim, [])) > 1}) if badchunks: raise ValueError(f"The dimension(s) over which we group, reduce or interpolate cannot be chunked ({badchunks}).") else: chunks = None # Dimensions untouched by the function. base_dims = list(set(ds.dims) - set(new_dims) - set(reduced_dims)) # All dimensions of the output data, new_dims are added at the end on purpose. all_dims = base_dims + new_dims # The coordinates of the output data. added_coords = [] coords = {} sizes = {} for dim in all_dims: if dim == group.prop: coords[group.prop] = group.get_coordinate(ds=ds) elif dim == group.dim: coords[group.dim] = ds[group.dim] elif dim in kwargs: coords[dim] = xr.DataArray(kwargs[dim], dims=(dim,), name=dim) elif dim in ds.dims: # If a dim has no coords : some sdba function will add them, so to be safe we add them right now # and note them to remove them afterwards. if dim not in ds.coords: added_coords.append(dim) ds[dim] = ds[dim] coords[dim] = ds[dim] else: raise ValueError(f"This function adds the {dim} dimension, its coordinate must be provided as a keyword argument.") sizes.update({name: crd.size for name, crd in coords.items()}) # Create the output dataset, but empty tmpl = xr.Dataset(coords=coords) if isinstance(ds, xr.Dataset): # Get largest dtype of the inputs, assign it to the output. dtype = max((da.dtype for da in ds.data_vars.values()), key=lambda d: d.itemsize) else: dtype = ds.dtype for var, dims in out_vars.items(): var_new_dims = [] for dim in dims: var_new_dims.extend(placeholders.get(dim, [dim])) # Out variables must have the base dims + new_dims dims = base_dims + var_new_dims # duck empty calls dask if chunks is not None tmpl[var] = duck_empty(dims, sizes, dtype=dtype, chunks=chunks) def _call_and_transpose_on_exit(dsblock, **f_kwargs): """Call the decorated func and transpose to ensure the same dim order as on the template.""" try: _decode_cf_coords(dsblock) func_out = func(dsblock, **f_kwargs).transpose(*all_dims) except Exception as err: raise ValueError(f"{func.__name__} failed on block with coords : {dsblock.coords}.") from err return func_out # Fancy patching for explicit dask task names _call_and_transpose_on_exit.__name__ = f"block_{func.__name__}" # Remove all auxiliary coords on both tmpl and ds extra_coords = {name: crd for name, crd in ds.coords.items() if name not in crd.dims} ds = ds.drop_vars(extra_coords.keys()) # Coords not sharing dims with `all_dims` (like scalar aux coord on reduced 1D input) are absent from tmpl tmpl = tmpl.drop_vars(extra_coords.keys(), errors="ignore") # Call out = ds.map_blocks(_call_and_transpose_on_exit, template=tmpl, kwargs=kwargs) # Add back the extra coords, but only those which have compatible dimensions (like xarray would have done) out = out.assign_coords({name: crd for name, crd in extra_coords.items() if set(crd.dims).issubset(out.dims)}) # Finally remove coords we added... 'ignore' in case they were already removed. out = out.drop_vars(added_coords, errors="ignore") return out _map_blocks.__dict__["func"] = func return _map_blocks return _decorator
[docs] def map_groups(reduces: Sequence[str] | None = None, main_only: bool = False, **out_vars) -> Callable: r""" Decorator for declaring functions acting only on groups and wrapping them into a map_blocks. This is the same as `map_blocks` but adds a call to `group.apply()` in the mapped func and the default value of `reduces` is changed. The decorated function must have the signature: ``func(ds, dim, **kwargs)``. Where ds is a DataAray or Dataset, dim is the `group.dim` (and add_dims). The `group` argument is stripped from the kwargs, but must evidently be provided in the call. Parameters ---------- reduces : sequence of str, optional Dimensions that are removed from the inputs by the function. Defaults to [Grouper.DIM, Grouper.ADD_DIMS] if main_only is False, and [Grouper.DIM] if main_only is True. See :py:func:`map_blocks`. main_only : bool Same as for :py:meth:`Grouper.apply`. **out_vars Mapping from variable names in the output to their *new* dimensions. The placeholders ``Grouper.PROP``, ``Grouper.DIM`` and ``Grouper.ADD_DIMS`` can be used to signify ``group.prop``,``group.dim`` and ``group.add_dims``, respectively. If an output keeps a dimension that another loses, that dimension name must be given in `reduces` and in the list of new dimensions of the first output. See Also -------- map_blocks """ def_reduces = [Grouper.DIM] if not main_only: def_reduces.append(Grouper.ADD_DIMS) reduces = reduces or def_reduces def _decorator(func): decorator = map_blocks(reduces=reduces, **out_vars) def _apply_on_group(dsblock, **kwargs): group = kwargs.pop("group") return group.apply(func, dsblock, main_only=main_only, **kwargs) # Fancy patching for explicit dask task names _apply_on_group.__name__ = f"group_{func.__name__}" # wraps(func, injected=['dim'], hide_wrapped=True)( wrapper = decorator(_apply_on_group) wrapper.__dict__["func"] = func return wrapper return _decorator
# XC: core.utils
[docs] def ensure_chunk_size(da: xr.DataArray, **minchunks: int) -> xr.DataArray: r""" Ensure that the input DataArray has chunks of at least the given size. If only one chunk is too small, it is merged with an adjacent chunk. If many chunks are too small, they are grouped together by merging adjacent chunks. Parameters ---------- da : xr.DataArray The input DataArray, with or without the dask backend. Does nothing when passed a non-dask array. **minchunks : dict[str, int] A kwarg mapping from dimension name to minimum chunk size. Pass -1 to force a single chunk along that dimension. Returns ------- xr.DataArray """ if not uses_dask(da): return da all_chunks = dict(zip(da.dims, da.chunks, strict=False)) chunking = {} for dim, minchunk in minchunks.items(): chunks = all_chunks[dim] if minchunk == -1 and len(chunks) > 1: # Rechunk to single chunk only if it's not already one chunking[dim] = -1 toosmall = np.array(chunks) < minchunk # Chunks that are too small if toosmall.sum() > 1: # Many chunks are too small, merge them by groups fac = np.ceil(minchunk / min(chunks)).astype(int) chunking[dim] = tuple(sum(chunks[i : i + fac]) for i in range(0, len(chunks), fac)) # Reset counter is case the last chunks are still too small chunks = chunking[dim] toosmall = np.array(chunks) < minchunk if toosmall.sum() == 1: # Only one, merge it with adjacent chunk ind = np.where(toosmall)[0][0] new_chunks = list(chunks) sml = new_chunks.pop(ind) new_chunks[max(ind - 1, 0)] += sml chunking[dim] = tuple(new_chunks) if chunking: return da.chunk(chunks=chunking) return da
# XC: core.utils
[docs] def uses_dask(*das: xr.DataArray | xr.Dataset) -> bool: r""" Evaluate whether dask is installed and array is loaded as a dask array. Parameters ---------- *das : xr.DataArray or xr.Dataset DataArrays or Datasets to check. Returns ------- bool True if any of the passed objects is using dask. """ if len(das) > 1: return any([uses_dask(da) for da in das]) da = das[0] if isinstance(da, xr.DataArray) and isinstance(da.data, dsk.Array): return True if isinstance(da, xr.Dataset) and any(isinstance(var.data, dsk.Array) for var in da.variables.values()): return True return False
# XC: core
[docs] def get_op(op: str, constrain: Sequence[str] | None = None) -> Callable: """ Get python's comparing function according to its name of representation and validate allowed usage. Parameters ---------- op : str Operator. constrain : sequence of str, optional A tuple of allowed operators. """ # XC binary_ops = {">": "gt", "<": "lt", ">=": "ge", "<=": "le", "==": "eq", "!=": "ne"} if op in binary_ops: binary_op = binary_ops[op] elif op in binary_ops.values(): binary_op = op else: raise ValueError(f"Operation `{op}` not recognized.") constraints = [] if isinstance(constrain, list | tuple | set): constraints.extend([binary_ops[c] for c in constrain]) constraints.extend(constrain) elif isinstance(constrain, str): constraints.extend([binary_ops[constrain], constrain]) if constrain: if op not in constraints: raise ValueError(f"Operation `{op}` not permitted for indice.") return getattr(operator, f"__{binary_op}__")
# XC: calendar # TODO: Do not allow this? def _interpolate_doy_calendar(source: xr.DataArray, doy_max: int, doy_min: int = 1) -> xr.DataArray: """ Interpolate from one set of dayofyear range to another. Interpolate an array defined over a `dayofyear` range (say 1 to 360) to another `dayofyear` range (say 1 to 365). Parameters ---------- source : xr.DataArray Array with `dayofyear` coordinates. doy_max : int The largest day of the year allowed by calendar. doy_min : int The smallest day of the year in the output. This parameter is necessary when the target time series does not span over a full year (e.g. JJA season). Default is 1. Returns ------- xr.DataArray Interpolated source array over coordinates spanning the target `dayofyear` range. """ if "dayofyear" not in source.coords.keys(): raise AttributeError("Source should have `dayofyear` coordinates.") # Interpolate to fill na values da = source if uses_dask(source): # interpolate_na cannot run on chunked dayofyear. da = source.chunk({"dayofyear": -1}) filled_na = da.interpolate_na(dim="dayofyear") # Interpolate to target dayofyear range filled_na.coords["dayofyear"] = np.linspace(start=doy_min, stop=doy_max, num=len(filled_na.coords["dayofyear"])) return filled_na.interp(dayofyear=range(doy_min, doy_max + 1)) # XC: calendar
[docs] def parse_offset(freq: str) -> tuple[int, str, bool, str | None]: """ Parse an offset string. Parse a frequency offset and, if needed, convert to cftime-compatible components. Parameters ---------- freq : str Frequency offset. Returns ------- multiplier : int Multiplier of the base frequency. "[n]W" is always replaced with "[7n]D", as xarray doesn't support "W" for cftime indexes. offset_base : str Base frequency. is_start_anchored : bool Whether coordinates of this frequency should correspond to the beginning of the period (`True`) or its end (`False`). Can only be False when base is Y, Q or M; in other words, xsdba assumes frequencies finer than monthly are all start-anchored. anchor : str, optional Anchor date for bases Y or Q. As xarray doesn't support "W", neither does xsdba (anchor information is lost when given). """ # Useful to raise on invalid freqs, convert Y to A and get default anchor (A, Q) offset = pd.tseries.frequencies.to_offset(freq) base, *anchor = offset.name.split("-") anchor = anchor[0] if len(anchor) > 0 else None start = ("S" in base) or (base[0] not in "AYQM") if base.endswith("S") or base.endswith("E"): base = base[:-1] mult = offset.n if base == "W": mult = 7 * mult base = "D" anchor = None return mult, base, start, anchor
# XC : calendar
[docs] def compare_offsets(freqA: str, op: str, freqB: str) -> bool: """ Compare offsets string based on their approximate length, according to a given operator. Offset are compared based on their length approximated for a period starting after 1970-01-01 00:00:00. If the offsets are from the same category (same first letter), only the multiplier prefix is compared (QS-DEC == QS-JAN, MS < 2MS). "Business" offsets are not implemented. Parameters ---------- freqA : str RHS Date offset string ('YS', '1D', 'QS-DEC', ...). op : {'<', '<=', '==', '>', '>=', '!='} Operator to use. freqB : str LHS Date offset string ('YS', '1D', 'QS-DEC', ...). Returns ------- bool Return freqA op freqB. """ # Get multiplier and base frequency t_a, b_a, _, _ = parse_offset(freqA) t_b, b_b, _, _ = parse_offset(freqB) if b_a != b_b: # Different base freq, compare length of first period after beginning of time. t = pd.date_range("1970-01-01T00:00:00.000", periods=2, freq=freqA) t_a = (t[1] - t[0]).total_seconds() t = pd.date_range("1970-01-01T00:00:00.000", periods=2, freq=freqB) t_b = (t[1] - t[0]).total_seconds() # else Same base freq, compare multiplier only. return get_op(op)(t_a, t_b)
# XC: calendar
[docs] def construct_offset(mult: int, base: str, start_anchored: bool, anchor: str | None): """ Reconstruct an offset string from its parts. Parameters ---------- mult : int The period multiplier (>= 1). base : str The base period string (one char). start_anchored : bool If True and base in [Y, Q, M], adds the "S" flag, False add "E". anchor : str, optional The month anchor of the offset. Defaults to JAN for bases YS and QS and to DEC for bases YE and QE. Returns ------- str An offset string, conformant to pandas-like naming conventions. Notes ----- This provides the mirror opposite functionality of :py:func:`parse_offset`. """ start = ("S" if start_anchored else "E") if base in "YAQM" else "" if anchor is None and base in "AQY": anchor = "JAN" if start_anchored else "DEC" return f"{mult if mult > 1 else ''}{base}{start}{'-' if anchor else ''}{anchor or ''}"
# XC: calendar # Names of calendars that have the same number of days for all years uniform_calendars = ("noleap", "all_leap", "365_day", "366_day", "360_day") # XC: calendar def _month_is_first_period_month(time, freq): """Return True if the given time is from the first month of freq.""" if isinstance(time, cftime.datetime): frq_monthly = xr.coding.cftime_offsets.to_offset("MS") frq = xr.coding.cftime_offsets.to_offset(freq) if frq_monthly.onOffset(time): return frq.onOffset(time) return frq.onOffset(frq_monthly.rollback(time)) # Pandas time = pd.Timestamp(time) frq_monthly = pd.tseries.frequencies.to_offset("MS") frq = pd.tseries.frequencies.to_offset(freq) if frq_monthly.is_on_offset(time): return frq.is_on_offset(time) return frq.is_on_offset(frq_monthly.rollback(time)) # XC: calendar # TODO: implement needed functions in stack_periods # move to processing
[docs] def stack_periods( da: xr.Dataset | xr.DataArray, window: int = 30, stride: int | None = None, min_length: int | None = None, freq: str = "YS", dim: str = "period", start: str = "1970-01-01", align_days: bool = True, pad_value=dtypes.NA, ): """ Construct a multi-period array. Stack different equal-length periods of `da` into a new 'period' dimension. This is similar to ``da.rolling(time=window).construct(dim, stride=stride)``, but adapted for arguments in terms of a base temporal frequency that might be non-uniform (years, months, etc.). It is reversible for some cases (see `stride`). A rolling-construct method will be much more performant for uniform periods (days, weeks). Parameters ---------- da : xr.Dataset or xr.DataArray An xarray object with a `time` dimension. Must have a uniform timestep length. Output might be strange if this does not use a uniform calendar (noleap, 360_day, all_leap). window : int The length of the moving window as a multiple of ``freq``. stride : int, optional At which interval to take the windows, as a multiple of ``freq``. For the operation to be reversible with :py:func:`unstack_periods`, it must divide `window` into an odd number of parts. Default is `window` (no overlap between periods). min_length : int, optional Windows shorter than this are not included in the output. Given as a multiple of ``freq``. Default is ``window`` (every window must be complete). Similar to the ``min_periods`` argument of ``da.rolling``. If ``freq`` is annual or quarterly and ``min_length == ``window``, the first period is considered complete if the first timestep is in the first month of the period. freq : str Units of ``window``, ``stride`` and ``min_length``, as a frequency string. Must be larger or equal to the data's sampling frequency. Note that this function offers an easier interface for non-uniform period (like years or months) but is much slower than a rolling-construct method. dim : str The new dimension name. start : str The `start` argument passed to :py:func:`xarray.date_range` to generate the new placeholder time coordinate. align_days : bool When True (default), an error is raised if the output would have unaligned days across periods. If `freq = 'YS'`, day-of-year alignment is checked and if `freq` is "MS" or "QS", we check day-in-month. Only uniform-calendar will pass the test for `freq='YS'`. For other frequencies, only the `360_day` calendar will work. This check is ignored if the sampling rate of the data is coarser than "D". pad_value : Any When some periods are shorter than others, this value is used to pad them at the end. Passed directly as argument ``fill_value`` to :py:func:`xarray.concat`, the default is the same as on that function. Returns ------- xr.DataArray A DataArray with a new `period` dimension and a `time` dimension with the length of the longest window. The new time coordinate has the same frequency as the input data but is generated using :py:func:`xarray.date_range` with the given `start` value. That coordinate is the same for all periods, depending on the choice of ``window`` and ``freq``, it might make sense. But for unequal periods or non-uniform calendars, it will certainly not. If ``stride`` is a divisor of ``window``, the correct timeseries can be reconstructed with :py:func:`unstack_periods`. The coordinate of `period` is the first timestep of each window. """ # Import in function to avoid cyclical imports from xsdba.units import ( # pylint: disable=import-outside-toplevel infer_sampling_units, units2str, ) stride = stride or window min_length = min_length or window if stride > window: raise ValueError(f"Stride must be less than or equal to window. Got {stride} > {window}.") srcfreq = xr.infer_freq(da.time) cal = da.time.dt.calendar use_cftime = da.time.dtype == "O" if ( # if srcfreq in ("D", "h", "min", "s", "ms", "us", "ns") # TODO: Can we remove compare_offsets, only used here compare_offsets(srcfreq, "<=", "D") and align_days and ((freq.startswith(("Y", "A")) and cal not in uniform_calendars) or (freq.startswith(("Q", "M")) and window > 1 and cal != "360_day")) ): if freq.startswith(("Y", "A")): u = "year" else: u = "month" raise ValueError( f"Stacking {window}{freq} periods will result in unaligned day-of-{u}. " f"Consider converting the calendar of your data to one with uniform {u} lengths, " "or pass `align_days=False` to disable this check." ) # Convert integer inputs to freq strings mult, *args = parse_offset(freq) # TODO: remove construct? (hard code construct-offset) win_frq = construct_offset(mult * window, *args) strd_frq = construct_offset(mult * stride, *args) minl_frq = construct_offset(mult * min_length, *args) # The same time coord as da, but with one extra element. # This way, the last window's last index is not returned as None by xarray's grouper. time2 = xr.DataArray( xr.date_range( da.time[0].item(), freq=srcfreq, calendar=cal, periods=da.time.size + 1, use_cftime=use_cftime, ), dims=("time",), name="time", ) periods = [] # longest = 0 # Iterate over strides, but recompute the full window for each stride start for strd_slc in da.resample(time=strd_frq).groups.values(): win_resamp = time2.isel(time=slice(strd_slc.start, None)).resample(time=win_frq) # Get slice for first group win_slc = list(win_resamp.groups.values())[0] if min_length < window: # If we ask for a min_length period instead is it complete ? min_resamp = time2.isel(time=slice(strd_slc.start, None)).resample(time=minl_frq) min_slc = list(min_resamp.groups.values())[0] open_ended = min_slc.stop is None else: # The end of the group slice is None if no outside-group value was found after the last element # As we added an extra step to time2, we avoid the case where a group ends exactly on the last element of ds open_ended = win_slc.stop is None if open_ended: # Too short, we got to the end break if ( strd_slc.start == 0 and parse_offset(freq)[1] in "YAQ" and min_length == window and not _month_is_first_period_month(da.time[0].item(), freq) ): # For annual or quarterly frequencies (which can be anchor-based), # if the first time is not in the first month of the first period, # then the first period is incomplete but by a fractional amount. continue periods.append( slice( strd_slc.start + win_slc.start, ((strd_slc.start + win_slc.stop) if win_slc.stop is not None else da.time.size), ) ) # Make coordinates lengths = xr.DataArray( [slc.stop - slc.start for slc in periods], dims=(dim,), attrs={"long_name": "Length of each period"}, ) longest = lengths.max().item() # Length as a pint-ready array : with proper units, but values are not usable as indexes anymore m, u = infer_sampling_units(da) lengths = lengths * m lengths.attrs["units"] = units2str(u) # Start points for each period and remember parameters for unstacking starts = xr.DataArray( [da.time[slc.start].item() for slc in periods], dims=(dim,), attrs={ "long_name": "Start of the period", # Save parameters so that we can unstack. "window": window, "stride": stride, "freq": freq, "unequal_lengths": int(len(np.unique(lengths)) > 1), }, ) # The "fake" axis that all periods share fake_time = xr.date_range(start, periods=longest, freq=srcfreq, calendar=cal, use_cftime=use_cftime) # Slice and concat along new dim. We drop the index and add a new one so that xarray can concat them together. out = xr.concat( [da.isel(time=slc).drop_vars("time").assign_coords(time=np.arange(slc.stop - slc.start)) for slc in periods], dim, join="outer", fill_value=pad_value, ) out = out.assign_coords( time=(("time",), fake_time, da.time.attrs.copy()), **{f"{dim}_length": lengths, dim: starts}, ) out.time.attrs.update(long_name="Placeholder time axis") return out
# XC: calendar
[docs] def unstack_periods(da: xr.DataArray | xr.Dataset, dim: str = "period"): """ Unstack an array constructed with :py:func:`stack_periods`. Can only work with periods stacked with a ``stride`` that divides ``window`` in an odd number of sections. When ``stride`` is smaller than ``window``, only the center-most stride of each window is kept, except for the beginning and end which are taken from the first and last windows. Parameters ---------- da : xr.DataArray As constructed by :py:func:`stack_periods`, attributes of the period coordinates must have been preserved. dim : str The period dimension name. Notes ----- The following table shows which strides are included (``o``) in the unstacked output. In this example, ``stride`` was a fifth of ``window`` and ``min_length`` was four (4) times ``stride``. The row index ``i`` the period index in the stacked dataset, columns are the stride-long section of the original timeseries. .. table:: Unstacking example with ``stride < window``. === === === === === === === === i 0 1 2 3 4 5 6 === === === === === === === === 3 x x o o 2 x x o x x 1 x x o x x 0 o o o x x === === === === === === === === """ from xsdba.units import ( # pylint: disable=import-outside-toplevel infer_sampling_units, ) try: starts = da[dim] window = starts.attrs["window"] stride = starts.attrs["stride"] freq = starts.attrs["freq"] unequal_lengths = bool(starts.attrs["unequal_lengths"]) except (AttributeError, KeyError) as err: raise ValueError(f"`unstack_periods` can't find the window, stride and freq attributes on the {dim} coordinates.") from err if unequal_lengths: try: lengths = da[f"{dim}_length"] except KeyError as err: raise ValueError(f"`unstack_periods` can't find the `{dim}_length` coordinate.") from err # Get length as number of points m, _ = infer_sampling_units(da.time) lengths = lengths // m else: # It is acceptable to lose "{dim}_length" if they were all equal lengths = xr.DataArray([da.time.size] * da[dim].size, dims=(dim,)) # Convert from the fake axis to the real one time_as_delta = da.time - da.time[0] if da.time.dtype == "O": # cftime can't add with np.timedelta64 (restriction comes from numpy which refuses to add O with m8) time_as_delta = pd.TimedeltaIndex(time_as_delta).to_pytimedelta() # this array is O, numpy complies else: # Xarray will return int when iterating over datetime values, this returns timestamps starts = pd.DatetimeIndex(starts) def _reconstruct_time(_time_as_delta, _start): times = _time_as_delta + _start return xr.DataArray(times, dims=("time",), coords={"time": times}, name="time") # Easy case: if window == stride: # just concat them all periods = [] for i, (start, length) in enumerate(zip(starts.values, lengths.values, strict=False)): real_time = _reconstruct_time(time_as_delta, start) periods.append(da.isel(**{dim: i}, drop=True).isel(time=slice(0, length)).assign_coords(time=real_time.isel(time=slice(0, length)))) return xr.concat(periods, "time") # Difficult and ambiguous case if (window / stride) % 2 != 1: raise NotImplementedError( "`unstack_periods` can't work with strides that do not divide the window into an odd number of parts." f"Got {window} / {stride} which is not an odd integer." ) # Non-ambiguous overlapping case Nwin = window // stride mid = (Nwin - 1) // 2 # index of the center window mult, *args = parse_offset(freq) strd_frq = construct_offset(mult * stride, *args) periods = [] for i, (start, length) in enumerate(zip(starts.values, lengths.values, strict=False)): real_time = _reconstruct_time(time_as_delta, start) slices = list(real_time.resample(time=strd_frq).groups.values()) if i == 0: slc = slice(slices[0].start, min(slices[mid].stop, length)) elif i == da.period.size - 1: slc = slice(slices[mid].start, min(slices[Nwin - 1].stop or length, length)) else: slc = slice(slices[mid].start, min(slices[mid].stop, length)) periods.append(da.isel(**{dim: i}, drop=True).isel(time=slc).assign_coords(time=real_time.isel(time=slc))) return xr.concat(periods, "time")