# Source code for xsdba._adjustment

"""
Adjustment Algorithms
=====================

This file defines the different steps, to be wrapped into the Adjustment objects.
"""

from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence

import numpy as np
import xarray as xr

from . import nbutils as nbu
from . import utils as u
from ._processing import _adapt_freq
from .base import Grouper, map_blocks, map_groups
from .detrending import PolyDetrend
from .options import set_options
from .processing import (
    escore,
    jitter_over_thresh,
    jitter_under_thresh,
    reordering,
    standardize,
)
from .units import convert_units_to
from .utils import _fitfunc_1d


def _adapt_freq_preprocess(ds, adapt_freq_thresh, group: Grouper | None, dim: str | None):
    if adapt_freq_thresh is None:
        return ds
    if (group is None) ^ (dim is None) is False:
        raise ValueError("Either `group` or `dim` must be None.")
    thresh = convert_units_to(adapt_freq_thresh, ds.sim)
    if group:
        out = _adapt_freq(ds, group=group, thresh=thresh).rename({"sim_ad": "sim"})
    else:
        out = _adapt_freq.func(ds, dim=dim, thresh=thresh).rename({"sim_ad": "sim"})
    ds = ds.assign({v: out[v] for v in out.data_vars})
    # `P0_ref` and `P0_hist` give enough information
    ds = ds.drop_vars("dP0")
    return ds


def _preprocess_dataset(
    ds: xr.Dataset,
    dim: str | list,
    adapt_freq_thresh: str | None = None,
    jitter_under_thresh_value: str | None = None,
    jitter_over_thresh_value: str | None = None,
    jitter_over_thresh_upper_bnd: str | None = None,
):
    dim = dim if isinstance(dim, list) else [dim]
    # uniformize the notation, change back at the end
    if rename_hist := ("hist" in ds):
        ds = ds.rename({"hist": "sim"})

    if jitter_under_thresh_value:
        ds["sim"] = jitter_under_thresh(ds.sim, jitter_under_thresh_value)

    if (jitter_over_thresh_value is None) ^ (jitter_over_thresh_upper_bnd is None):
        raise ValueError("`jitter_over_thresh_value` and `jitter_over_thresh_upper_bnd` must both be specified or both be `None` (default)")
    if jitter_over_thresh_value:
        ds["sim"] = jitter_over_thresh(ds.sim, jitter_over_thresh_value, jitter_over_thresh_upper_bnd)

    if adapt_freq_thresh:
        ds = _adapt_freq_preprocess(ds, adapt_freq_thresh, None, dim)

    else:
        # pick the dataset with the largest number of dimensions
        # keep add_dims on the new datasets
        dim = Grouper.filter_add_dims(dim)
        ds0 = ds.sim if len(ds.sim.dims) >= len(ds.ref.dims) else ds.ref
        dummy = xr.full_like(ds0[{d: 0 for d in dim}], np.nan)
        ds = ds.assign(P0_ref=dummy, P0_hist=dummy, pth=dummy)

    if rename_hist:
        ds = ds.rename({"sim": "hist"})

    return ds


@map_groups(
    af=[Grouper.PROP, "quantiles"],
    hist_q=[Grouper.PROP, "quantiles"],
    hist_q_raw=[Grouper.PROP, "quantiles"],
    scaling=[Grouper.PROP],
    P0_ref=[Grouper.PROP, Grouper.ADD_DIMS],
    P0_hist=[Grouper.PROP, Grouper.ADD_DIMS],
    pth=[Grouper.PROP, Grouper.ADD_DIMS],
)
def dqm_train(
    ds: xr.Dataset,
    *,
    dim: str,
    kind: str,
    quantiles: np.ndarray,
    adapt_freq_thresh: str | None = None,
    jitter_under_thresh_value: str | None = None,
    jitter_over_thresh_value: str | None = None,
    jitter_over_thresh_upper_bnd: str | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    DQM: Train step on one group.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : str
        The dimension along which to compute the quantiles.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    quantiles : array-like
        The quantiles to compute.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    jitter_under_thresh_value : str, optional
        Threshold under which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter under thresh is not performed.
    jitter_over_thresh_value : str, optional
        Threshold above which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    jitter_over_thresh_upper_bnd : str, optional
        Maximum possible value for the random noise, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps, stored in hist_q_raw) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.
        If None, hist_q_raw output will just be a dummy variable.

    Returns
    -------
    xr.Dataset
        The dataset containing the adjustment factors, the quantiles over the training data, the raw
        quantiles of hist (or a NaN dummy), the scaling factor and the frequency-adaptation variables.

    Notes
    -----
    `jitter_over_thresh_value` and `jitter_over_thresh_upper_bnd` must both be specified to
    use `jitter_over_thresh`, or both be None (default) to skip it.
    """
    if max_tail_factor is not None:
        # Quantiles of the raw hist (before any preprocessing), needed for `max_tail_factor`
        # in the adjustment step.
        raw_sim_dim = Grouper.filter_dim(ds.hist, dim)
        hist_q_raw = nbu.quantile(ds.hist, quantiles, raw_sim_dim)

    ds = _preprocess_dataset(
        ds,
        dim,
        adapt_freq_thresh,
        jitter_under_thresh_value,
        jitter_over_thresh_value,
        jitter_over_thresh_upper_bnd,
    )
    # Ensures extra dimensions are only aggregated in datasets that have them
    ref_dim = Grouper.filter_dim(ds.ref, dim)
    # ds.hist might have been broadcasted in preprocess, so `sim_dim` is computed after it
    sim_dim = Grouper.filter_dim(ds.hist, dim)

    # The means are used both to normalize the series and to compute the scaling factor:
    # compute them only once.
    mu_ref = ds.ref.mean(ref_dim)
    mu_hist = ds.hist.mean(sim_dim)
    refn = u.apply_correction(ds.ref, u.invert(mu_ref, kind), kind)
    histn = u.apply_correction(ds.hist, u.invert(mu_hist, kind), kind)

    ref_q = nbu.quantile(refn, quantiles, ref_dim)
    hist_q = nbu.quantile(histn, quantiles, sim_dim)
    if max_tail_factor is None:
        # make a dummy variable to keep the same output structure
        hist_q_raw = xr.full_like(hist_q, np.nan)

    af = u.get_correction(hist_q, ref_q, kind)
    scaling = u.get_correction(mu_hist, mu_ref, kind=kind)
    return xr.Dataset(
        data_vars={
            "af": af,
            "hist_q": hist_q,
            "hist_q_raw": hist_q_raw,
            "scaling": scaling,
            "P0_ref": ds.P0_ref,
            "P0_hist": ds.P0_hist,
            "pth": ds.pth,
        }
    )


@map_groups(
    af=[Grouper.PROP, "quantiles"],
    hist_q=[Grouper.PROP, "quantiles"],
    hist_q_raw=[Grouper.PROP, "quantiles"],
    P0_ref=[Grouper.PROP, Grouper.ADD_DIMS],
    P0_hist=[Grouper.PROP, Grouper.ADD_DIMS],
    pth=[Grouper.PROP, Grouper.ADD_DIMS],
)
def eqm_train(
    ds: xr.Dataset,
    *,
    dim: str,
    kind: str,
    quantiles: np.ndarray,
    adapt_freq_thresh: str | None = None,
    jitter_under_thresh_value: str | None = None,
    jitter_over_thresh_value: str | None = None,
    jitter_over_thresh_upper_bnd: str | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    EQM: Train step on one group.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : str
        The dimension along which to compute the quantiles.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    quantiles : array-like
        The quantiles to compute.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    jitter_under_thresh_value : str, optional
        Threshold under which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter under thresh is not performed.
    jitter_over_thresh_value : str, optional
        Threshold above which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    jitter_over_thresh_upper_bnd : str, optional
        Maximum possible value for the random noise, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps, stored in hist_q_raw) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.
        If None, hist_q_raw output will just be a dummy variable.

    Returns
    -------
    xr.Dataset
        The dataset containing the adjustment factors, the quantiles over the training data, the
        frequency-adaptation variables and the raw quantiles of hist (or a NaN dummy).

    Notes
    -----
    `jitter_over_thresh_value` and `jitter_over_thresh_upper_bnd` must both be specified to
    use `jitter_over_thresh`, or both be None (default) to skip it.
    """
    if max_tail_factor is not None:
        # Quantiles of the raw hist (before any preprocessing), needed for `max_tail_factor`
        # in the adjustment step.
        raw_sim_dim = Grouper.filter_dim(ds.hist, dim)
        hist_q_raw = nbu.quantile(ds.hist, quantiles, raw_sim_dim)

    ds = _preprocess_dataset(
        ds,
        dim,
        adapt_freq_thresh,
        jitter_under_thresh_value,
        jitter_over_thresh_value,
        jitter_over_thresh_upper_bnd,
    )

    # Ensures extra dimensions are only aggregated in datasets that have them
    ref_dim = Grouper.filter_dim(ds.ref, dim)
    # ds.hist might have been broadcasted in preprocess, so `sim_dim` is computed after it
    sim_dim = Grouper.filter_dim(ds.hist, dim)
    ref_q = nbu.quantile(ds.ref, quantiles, ref_dim)
    hist_q = nbu.quantile(ds.hist, quantiles, sim_dim)
    if max_tail_factor is None:
        # make a dummy variable to keep the same output structure
        hist_q_raw = xr.full_like(hist_q, np.nan)
    af = u.get_correction(hist_q, ref_q, kind)
    return xr.Dataset(
        data_vars={
            "af": af,
            "hist_q": hist_q,
            "P0_ref": ds.P0_ref,
            "P0_hist": ds.P0_hist,
            "pth": ds.pth,
            "hist_q_raw": hist_q_raw,
        }
    )


def _npdft_train(ref, hist, rots, quantiles, method, extrap, n_escore, standardize):
    r"""
    Npdf transform to correct a source `hist` into target `ref`.

    Perform a rotation, bias correct `hist` into `ref` with QuantileDeltaMapping, and rotate back.
    Do this iteratively over all rotations `rots` and conserve adjustment factors `af_q` in each iteration.

    Parameters
    ----------
    ref, hist : np.ndarray
        2-d arrays of shape `(len(nfeature), len(time))`.
    rots : np.ndarray
        Rotation matrices with shape `(len(iterations), len(nfeature), len(nfeature))`.
    quantiles : np.ndarray
        The quantiles at which the adjustment factors are computed.
    method : str
        Interpolation method passed to the quantile interpolation.
    extrap : str
        Extrapolation method passed to the quantile interpolation.
    n_escore : int
        Number of time steps sampled for the energy score after each rotation.
        0 uses every time step, a negative value skips the computation (scores stay NaN).
    standardize : bool
        If True, `ref` and `hist` are centered and scaled (nanmean / nanstd along time) first.

    Returns
    -------
    af_q : np.ndarray
        Additive adjustment factors, shape `(len(rots), len(nfeature), len(quantiles))`.
    escores : np.ndarray
        Energy score after each rotation, shape `(len(rots),)`.

    Notes
    -----
    This function expects numpy inputs. The input arrays `ref,hist` are expected to be 2-dimensional arrays with shape:
    `(len(nfeature), len(time))`, where `nfeature` is the dimension which is mixed by the multivariate bias adjustment
    (e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape
    `(len(iterations), len(nfeature), len(nfeature))`.
    """
    if standardize:
        ref = (ref - np.nanmean(ref, axis=-1, keepdims=True)) / np.nanstd(ref, axis=-1, keepdims=True)
        hist = (hist - np.nanmean(hist, axis=-1, keepdims=True)) / np.nanstd(hist, axis=-1, keepdims=True)
    af_q = np.zeros((len(rots), ref.shape[0], len(quantiles)))
    escores = np.full(len(rots), np.nan)
    compute_escore = n_escore >= 0
    if compute_escore:
        # Subsample the time axis so that about `n_escore` points enter the score;
        # `n_escore == 0` means "use all points" (stride 1).
        ref_step, hist_step = (int(np.ceil(arr.shape[1] / n_escore)) if n_escore > 0 else 1 for arr in (ref, hist))
    for ii, _rot in enumerate(rots):
        # Undo the previous rotation and apply the current one in a single product.
        rot = _rot if ii == 0 else _rot @ rots[ii - 1].T
        ref, hist = rot @ ref, rot @ hist
        # loop over variables
        for iv in range(ref.shape[0]):
            ref_q, hist_q = nbu._quantile(ref[iv], quantiles), nbu._quantile(hist[iv], quantiles)
            af_q[ii, iv] = ref_q - hist_q
            af = u._interp_on_quantiles_1D(
                u._rank_bn(hist[iv]),
                quantiles,
                af_q[ii, iv],
                method=method,
                extrap=extrap,
            )
            hist[iv] = hist[iv] + af
        if compute_escore:
            escores[ii] = nbu._escore(ref[:, ::ref_step], hist[:, ::hist_step])
    # The final back-rotation of `hist` is not needed: only `af_q` and `escores` are returned.
    return af_q, escores


def mbcn_train(
    ds: xr.Dataset,
    rot_matrices: xr.DataArray,
    pts_dims: Sequence[str],
    quantiles: np.ndarray,
    gw_idxs: xr.DataArray,
    interp: str,
    extrapolation: str,
    n_escore: int,
) -> xr.Dataset:
    """
    Npdf transform training.

    Adjusting factors obtained for each rotation in the npdf transform and conserved to be applied in
    the adjusting step in :py:func:`mbcn_adjust`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    rot_matrices : xr.DataArray
        The rotation matrices as a 3D array ('iterations', <pts_dims[0]>, <pts_dims[1]>), with shape (n_iter, <N>, <N>).
    pts_dims : sequence of str
        The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which
        is the normal case when using :py:func:`xsdba.stack_variables`, and "multivar_prime".
    quantiles : array-like
        The quantiles to compute.
    gw_idxs : xr.DataArray
        Indices of the times in each windowed time group.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.
    n_escore : int
        Number of elements to include in the e_score test (0 for all, < 0 to skip).

    Returns
    -------
    xr.Dataset
        The dataset containing the adjustment factors and the quantiles over the training data
        (only the npdf transform of mbcn).
    """
    # unpack data
    ref = ds.ref
    hist = ds.hist
    gr_dim = gw_idxs.attrs["group_dim"]

    # npdf training core
    af_q_l = []
    escores_l = []

    # loop over time blocks
    for ib in range(gw_idxs[gr_dim].size):
        # indices in a given time block (missing entries are NaN and filtered out)
        indices = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values
        ind = indices[indices >= 0]

        # npdft training : multiple rotations on standardized datasets
        # keep track of adjustment factors in each rotation for later use
        af_q, escores = xr.apply_ufunc(
            _npdft_train,
            ref[{"time": ind}],
            hist[{"time": ind}],
            rot_matrices,
            quantiles,
            input_core_dims=[
                [pts_dims[0], "time"],
                [pts_dims[0], "time"],
                ["iterations", pts_dims[1], pts_dims[0]],
                ["quantiles"],
            ],
            output_core_dims=[
                ["iterations", pts_dims[1], "quantiles"],
                ["iterations"],
            ],
            dask="parallelized",
            output_dtypes=[hist.dtype, hist.dtype],
            kwargs={
                "method": interp,
                "extrap": extrapolation,
                "n_escore": n_escore,
                "standardize": True,
            },
            vectorize=True,
        )
        af_q_l.append(af_q.expand_dims({gr_dim: [ib]}))
        escores_l.append(escores.expand_dims({gr_dim: [ib]}))
    af_q = xr.concat(af_q_l, dim=gr_dim)
    escores = xr.concat(escores_l, dim=gr_dim)
    out = xr.Dataset({"af_q": af_q, "escores": escores}).assign_coords(
        {"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values}
    )
    return out
def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap):
    """
    Npdf transform adjusting.

    Adjusting factors `af_q` obtained in the training step are applied on the simulated data `sim` at each
    iterated rotation, see :py:func:`_npdft_train`.

    This function expects numpy inputs. `sim` can be a 2-d array with shape: `(len(nfeature), len(time))`, or
    a 3-d array with shape: `(len(nfeature), len(period), len(time))`, allowing to adjust multiple climatological
    periods all at once. `nfeature` is the dimension which is mixed by the multivariate bias adjustment
    (e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. It is always the first axis:
    the rotations contract it and the inner loop iterates over it. `rots` are rotation matrices with shape
    `(len(iterations), len(nfeature), len(nfeature))`.

    Returns
    -------
    np.ndarray
        The adjusted `sim`, with the same shape as the input.
    """
    # add dummy dim if period_dim absent to uniformize the function below
    # This could be done at higher level, not sure where is best
    if dummy_dim_added := (len(sim.shape) == 2):
        sim = sim[:, np.newaxis, :]

    # adjust npdft
    for ii, _rot in enumerate(rots):
        # Undo the previous rotation and apply the current one in a single product.
        rot = _rot if ii == 0 else _rot @ rots[ii - 1].T
        # Rotate along the first (feature) axis; einsum allocates a new array,
        # so the in-place update below never touches the caller's input.
        sim = np.einsum("ij,j...->i...", rot, sim)
        # loop over variables
        for iv in range(sim.shape[0]):
            af = u._interp_on_quantiles_1D_multi(
                u._rank_bn(sim[iv], axis=-1),
                quantiles,
                af_q[ii, iv],
                method=method,
                extrap=extrap,
            )
            sim[iv] = sim[iv] + af

    # rotate back to the original feature space
    rot = rots[-1].T
    sim = np.einsum("ij,j...->i...", rot, sim)
    if dummy_dim_added:
        sim = sim[:, 0, :]
    return sim
def mbcn_adjust(
    ref: xr.DataArray,
    hist: xr.DataArray,
    sim: xr.DataArray,
    ds: xr.Dataset,
    g_idxs: xr.DataArray,
    gw_idxs: xr.DataArray,
    pts_dims: tuple[str, str],
    interp: str,
    extrapolation: str,
    base: Callable,
    base_kws_vars: dict,
    adj_kws: dict,
    period_dim: str | None,
) -> xr.Dataset:
    """
    Perform the adjustment portion of the MBCn multivariate bias correction technique.

    The function :py:func:`mbcn_train` pre-computes the adjustment factors for each rotation
    in the npdf portion of the MBCn algorithm. The rest of adjustment is performed here
    in ``mbcn_adjust``.

    Parameters
    ----------
    ref : xr.DataArray
        training target.
    hist : xr.DataArray
        training data.
    sim : xr.DataArray
        data to adjust (stacked with multivariate dimension).
    ds : xr.Dataset
        Dataset variables:
            rot_matrices : Rotation matrices used in the training step.
            af_q : Adjustment factors obtained in the training step for the npdf transform
    g_idxs : xr.DataArray
        Indices of the times in each time group.
    gw_idxs : xr.DataArray
        Indices of the times in each windowed time group.
    pts_dims : [str, str]
        The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which
        is the normal case when using :py:func:`xsdba.stack_variables`, and "multivar_prime".
    interp : str
        Interpolation method for the npdf transform (same as in the training step).
    extrapolation : str
        Extrapolation method for the npdf transform (same as in the training step).
    base : BaseAdjustment
        Bias-adjustment class used for the univariate bias correction.
    base_kws_vars : Dict
        Options for univariate training for the scenario that is reordered with the output of npdf transform.
        The arguments are those expected by TrainAdjust classes along with
        - kinds : Dict of correction kinds for each variable (e.g. {"pr":"*", "tasmax":"+"}).
    adj_kws : Dict
        Options for univariate adjust for the scenario that is reordered with the output of npdf transform.
    period_dim : str, optional
        Name of the period dimension used when stacking time periods of `sim` using :py:func:`xsdba.stack_periods`.
        If specified, the interpolation of the npdf transform is performed only once and applied on all periods
        simultaneously. This should be more performant, but also more memory intensive. Defaults to `None`:
        No optimization will be attempted.

    Returns
    -------
    xr.Dataset
        The adjusted data, as variable ``scen``.
    """
    # unpacking training parameters
    rot_matrices = ds.rot_matrices
    af_q = ds.af_q
    quantiles = af_q.quantiles
    gr_dim = gw_idxs.attrs["group_dim"]
    win = gw_idxs.attrs["group"][1]

    # this way of handling was letting open the possibility to perform
    # interpolation for multiple periods in the simulation all at once
    # in principle, avoiding redundancy. Need to test this on small data
    # to confirm it works, and on big data to check performance.
    dims = ["time"] if period_dim is None else [period_dim, "time"]

    # mbcn core
    scen_mbcn = xr.zeros_like(sim)
    for ib in range(gw_idxs[gr_dim].size):
        # indices in a given time block (with and without the window);
        # missing entries are NaN and filtered out
        indices_gw = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values
        ind_gw = indices_gw[indices_gw >= 0]
        indices_g = g_idxs[{gr_dim: ib}].fillna(-1).astype(int).values
        ind_g = indices_g[indices_g >= 0]
        sim_block = sim[{"time": ind_gw}]

        # 1. univariate adjustment of sim -> scen
        # the kind may differ depending on the variables
        scen_block = xr.zeros_like(sim_block)
        for iv, v in enumerate(sim[pts_dims[0]].values):
            sl = {"time": ind_gw, pts_dims[0]: iv}
            with set_options(extra_output=False):
                ADJ = base.train(ref[sl], hist[sl], **base_kws_vars[v], skip_input_checks=True)
                scen_block[{pts_dims[0]: iv}] = ADJ.adjust(sim[sl], **adj_kws, skip_input_checks=True)

        # 2. npdft adjustment of sim
        npdft_block = xr.apply_ufunc(
            _npdft_adjust,
            standardize(sim_block.copy(), dim="time")[0],
            af_q[{gr_dim: ib}],
            rot_matrices,
            quantiles,
            input_core_dims=[
                [pts_dims[0]] + dims,
                ["iterations", pts_dims[1], "quantiles"],
                ["iterations", pts_dims[1], pts_dims[0]],
                ["quantiles"],
            ],
            output_core_dims=[
                [pts_dims[0]] + dims,
            ],
            dask="parallelized",
            output_dtypes=[sim.dtype],
            kwargs={"method": interp, "extrap": extrapolation},
            vectorize=True,
        )

        # 3. reorder scen according to npdft results
        reordered = reordering(ref=npdft_block, sim=scen_block)
        if win > 1:
            # keep central value of window (intersecting indices in gw_idxs and g_idxs)
            scen_mbcn[{"time": ind_g}] = reordered[{"time": np.isin(ind_gw, ind_g)}]
        else:
            scen_mbcn[{"time": ind_g}] = reordered

    return scen_mbcn.to_dataset(name="scen")
def _adapt_sim_freq_inplace(ds: xr.Dataset, group: Grouper, adapt_freq_thresh: str | None) -> None:
    """Replace ``ds["sim"]`` in place by its frequency-adapted version when `adapt_freq_thresh` is given."""
    if adapt_freq_thresh:
        ds["sim"] = _adapt_freq_preprocess(
            ds[["sim", "P0_ref", "P0_hist", "pth"]],
            adapt_freq_thresh,
            group=Grouper(group.name),
            dim=None,
        ).sim


def _max_tail_mask(
    ds: xr.Dataset, group: Grouper, interp: str, max_tail_factor: float | None
) -> tuple[xr.DataArray, xr.DataArray] | tuple[None, None]:
    """
    Build the mask of values that must not be bias-adjusted.

    Returns ``(adaptedsim, mask)``: a copy of ``ds.sim`` and a boolean mask flagging where it is larger
    than `max_tail_factor` times the last quantile of ``ds.hist_q_raw`` (broadcast onto ``ds.sim``).
    Returns ``(None, None)`` when `max_tail_factor` is None.
    """
    if max_tail_factor is None:
        return None, None
    adaptedsim = ds["sim"].copy()
    last_quantile = ds["hist_q_raw"].isel({"quantiles": -1}).drop_vars("quantiles")
    # make last_quantile dim fit adaptedsim dim
    last_quantile = u.broadcast(
        last_quantile,
        adaptedsim,
        group=group,
        interp=interp if group.prop != "dayofyear" else "nearest",
    )
    return adaptedsim, adaptedsim > max_tail_factor * last_quantile


@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[])
def qm_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    interp: str,
    extrapolation: str,
    kind: str,
    adapt_freq_thresh: str | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    QM (DQM and EQM): Adjust step on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors
            hist_q : Quantiles over the training data
            sim : Data to adjust.
            P0_ref (optional) : Proportion of zeroes in the reference
            P0_hist (optional) : Proportion of zeroes in the historical period of the simulation
            pth (optional) : The smallest value of `hist` that was not frequency-adjusted in the training.
            hist_q_raw (optional) : Quantiles of the raw hist, required when `max_tail_factor` is given.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.

    Returns
    -------
    xr.Dataset
        The adjusted data.
    """
    _adapt_sim_freq_inplace(ds, group, adapt_freq_thresh)
    # mask no bias adjustment, when sim is larger than n times the largest quantile in hist (without adapt freq)
    adaptedsim, mask = _max_tail_mask(ds, group, interp, max_tail_factor)

    af = u.interp_on_quantiles(
        ds.sim,
        ds.hist_q,
        ds.af,
        group=group,
        method=interp,
        extrapolation=extrapolation,
    )
    scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen")

    # apply max_tail_factor mask
    if mask is not None:
        scen = scen.where(~mask, adaptedsim)
    return scen.to_dataset()


@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], trend=[])
def dqm_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    interp: str,
    kind: str,
    extrapolation: str,
    detrend: int | PolyDetrend,
    adapt_freq_thresh: str | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    DQM adjustment on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            scaling : Scaling factor between ref and hist
            af : Adjustment factors
            hist_q : Quantiles over the training data
            sim : Data to adjust
            P0_ref (optional) : Proportion of zeroes in the reference
            P0_hist (optional) : Proportion of zeroes in the historical period of the simulation
            pth (optional) : The smallest value of `hist` that was not frequency-adjusted in the training.
            hist_q_raw (optional) : Quantiles of the raw hist, required when `max_tail_factor` is given.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method to use.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    extrapolation : str
        The extrapolation method to use.
    detrend : int | PolyDetrend
        The degree of the polynomial detrending to apply. If 0, no detrending is applied.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.

    Returns
    -------
    xr.Dataset
        The adjusted data and the trend.
    """
    _adapt_sim_freq_inplace(ds, group, adapt_freq_thresh)
    # mask no bias adjustment, when sim is larger than n times the largest quantile in hist (without adapt freq)
    adaptedsim, mask = _max_tail_mask(ds, group, interp, max_tail_factor)

    scaled_sim = u.apply_correction(
        ds.sim,
        u.broadcast(
            ds.scaling,
            ds.sim,
            group=group,
            interp=interp if group.prop != "dayofyear" else "nearest",
        ),
        kind,
    ).assign_attrs({"units": ds.sim.units})

    if isinstance(detrend, int):
        detrending = PolyDetrend(degree=detrend, kind=kind, group=group)
    else:
        detrending = detrend

    detrending = detrending.fit(scaled_sim)
    ds["sim"] = detrending.detrend(scaled_sim)
    # frequency adaptation and tail masking were already handled above, so they are not forwarded
    scen = qm_adjust.func(
        ds,
        group=group,
        interp=interp,
        extrapolation=extrapolation,
        kind=kind,
    ).scen
    scen = detrending.retrend(scen)

    # apply max_tail_factor mask
    if mask is not None:
        scen = scen.where(~mask, adaptedsim)

    out = xr.Dataset({"scen": scen, "trend": detrending.ds.trend})
    return out
@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], sim_q=[])
def qdm_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    interp: str,
    extrapolation: str,
    kind: str,
    adapt_freq_thresh: str | None = None,
    rank_window: bool | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    QDM adjustment on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors, defined along the `quantiles` coordinate.
            sim : Data to adjust.
            hist_q_raw (optional) : Quantiles of `hist` before the preprocessing steps.
                Required when `max_tail_factor` is given.
            P0_ref, P0_hist, pth (optional) : Frequency-adaptation variables from the training.
                Required when `adapt_freq_thresh` is given.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    rank_window : bool, optional
        Whether to rank simulated values over the full grouping window.
        Effectively, the default is `False` which preserves legacy behavior.
        If `False`, ranks are computed within exact groups, e.g., a specific day of year.
        In `xsdba>=0.8`, this option will be deprecated in favour of honoring the grouping window (equivalent to `True`).
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above `max_tail_factor`
        times the value of the last quantile of hist (before the preprocessing steps) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.

    Returns
    -------
    xr.Dataset
        The adjusted data (`scen`) and the percentile ranks of `sim` within its groups (`sim_q`).

    Warns
    -----
    DeprecationWarning
        If `rank_window` is None and the group window is larger than one, a warning indicates that in `xsdba>=0.8`
        the current behavior will change.
    """
    if adapt_freq_thresh:
        ds["sim"] = _adapt_freq_preprocess(
            ds[["sim", "P0_ref", "P0_hist", "pth"]],
            adapt_freq_thresh,
            group=Grouper(group.name),
            dim=None,
        ).sim

    # Values of `sim` (after frequency adaptation) larger than `max_tail_factor` times the
    # largest raw `hist` quantile are left unadjusted at the end.
    if max_tail_factor is not None:
        adaptedsim = ds["sim"].copy()
        last_quantile = ds["hist_q_raw"].isel({"quantiles": -1}).drop_vars("quantiles")
        # Broadcast the per-group quantile onto the dimensions of `sim`.
        last_quantile = u.broadcast(
            last_quantile,
            adaptedsim,
            group=group,
            interp=interp if group.prop != "dayofyear" else "nearest",
        )
        mask = adaptedsim > max_tail_factor * last_quantile

    if rank_window is None:
        rank_window = False
        if group.window > 1:
            warnings.warn(
                "QDM method can now perform the adjustment step by expanding the time dimension "
                "with the same window as used in the training. This can already be used by setting "
                "`rank_window = True`. This will be the only possible behaviour in `xsdba>=0.8`. "
                "The current behaviour is obtained by setting `rank_window = False` and will be "
                "deprecated in `xsdba>=0.8`. It will still be possible to use the old behaviour by "
                "monkeypatching the group argument in the QDM class between the training and "
                "adjustment, though this behaviour is not recommended.",
                category=DeprecationWarning,
                stacklevel=2,
            )

    # Percentile rank of each simulated value, within the main group only unless `rank_window`.
    sim_q = group.apply(u.rank, ds.sim, main_only=not rank_window, pct=True)
    af = u.interp_on_quantiles(
        sim_q,
        ds.quantiles,
        ds.af,
        group=group,
        method=interp,
        extrapolation=extrapolation,
    )
    scen = u.apply_correction(ds.sim, af, kind)
    # apply max_tail_factor mask
    if max_tail_factor is not None:
        scen = scen.where(~mask, adaptedsim)
    return xr.Dataset({"scen": scen, "sim_q": sim_q})


@map_blocks(
    reduces=[Grouper.ADD_DIMS, Grouper.DIM],
    af=[Grouper.PROP],
    hist_thresh=[Grouper.PROP],
)
def loci_train(ds: xr.Dataset, *, group, thresh) -> xr.Dataset:
    """
    LOCI: Train on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    group : Grouper
        The grouper object.
    thresh : scalar
        Threshold applied to `ref`.

    Returns
    -------
    xr.Dataset
        af : Multiplicative factor between the group means of `ref` exceedances over `thresh`
            and of `hist` exceedances over `hist_thresh`.
        hist_thresh : Value of `hist` mapped (through the CDFs) to `thresh` in `ref`.
    """
    s_thresh = group.apply(u.map_cdf, ds.rename(hist="x", ref="y"), y_value=thresh).isel(x=0)
    sth = u.broadcast(s_thresh, ds.hist, group=group)
    # Keep only the values above each series' threshold.
    ws = xr.where(ds.hist >= sth, ds.hist, np.nan)
    wo = xr.where(ds.ref >= thresh, ds.ref, np.nan)

    ms = group.apply("mean", ws, skipna=True)
    mo = group.apply("mean", wo, skipna=True)

    # Adjustment factor
    af = u.get_correction(ms - s_thresh, mo - thresh, u.MULTIPLICATIVE)
    return xr.Dataset({"af": af, "hist_thresh": s_thresh})


@map_blocks(reduces=[Grouper.PROP], scen=[])
def loci_adjust(ds: xr.Dataset, *, group, thresh, interp) -> xr.Dataset:
    """
    LOCI: Adjust on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors
            hist_thresh : Hist's equivalent thresh from ref
            sim : Data to adjust
    group : Grouper
        The grouper object.
    thresh : scalar
        Threshold on the `ref` scale, added back after scaling.
    interp : str
        The interpolation method used to broadcast the factors.

    Returns
    -------
    xr.Dataset
        The adjusted data (`scen`), clipped below at 0.
    """
    sth = u.broadcast(ds.hist_thresh, ds.sim, group=group, interp=interp)
    factor = u.broadcast(ds.af, ds.sim, group=group, interp=interp)
    with xr.set_options(keep_attrs=True):
        scen: xr.DataArray = (factor * (ds.sim - sth) + thresh).clip(min=0).rename("scen")
    out = scen.to_dataset()
    return out


@map_groups(af=[Grouper.PROP])
def scaling_train(ds: xr.Dataset, *, dim, kind) -> xr.Dataset:
    """
    Scaling: Train on one group.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : str or list
        Dimension(s) over which the means are taken.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.

    Returns
    -------
    xr.Dataset
        The correction (`af`) between the means of `hist` and `ref`.
    """
    ref_dim = Grouper.filter_dim(ds.ref, dim)
    hist_dim = Grouper.filter_dim(ds.hist, dim)
    mhist = ds.hist.mean(hist_dim)
    mref = ds.ref.mean(ref_dim)
    af: xr.DataArray = u.get_correction(mhist, mref, kind).rename("af")
    out = af.to_dataset()
    return out


@map_blocks(reduces=[Grouper.PROP], scen=[])
def scaling_adjust(ds: xr.Dataset, *, group, interp, kind) -> xr.Dataset:
    """
    Scaling: Adjust on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors.
            sim : Data to adjust.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method used to broadcast the factors.
    kind : str
        The kind of correction to apply. See :py:func:`xsdba.utils.apply_correction`.

    Returns
    -------
    xr.Dataset
        The adjusted data (`scen`).
    """
    af = u.broadcast(ds.af, ds.sim, group=group, interp=interp)
    scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen")
    out = scen.to_dataset()
    return out
def npdf_transform(ds: xr.Dataset, **kwargs) -> xr.Dataset:
    r"""
    N-pdf transform : Iterative univariate adjustment in random rotated spaces.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : Reference multivariate timeseries (time dimension `time_hist`).
            hist : simulated timeseries on the reference period (time dimension `time_hist`).
            sim : Simulated timeseries on the projected period.
            rot_matrices : Random rotation matrices, one per element of the `iterations` dimension.
    **kwargs
        pts_dim : multivariate dimension name.
        base : Adjustment class.
        base_kws : Kwargs for initialising the adjustment object.
        adj_kws : Kwargs of the `adjust` call.
        n_escore : Number of elements to include in the e_score test (0 for all, < 0 to skip).

    Returns
    -------
    xr.Dataset
        Dataset variables:
            scenh : Scenario in the reference period (source `hist` transferred to target `ref` inside training).
            scen : Scenario in the projected period (source `sim` transferred to target `ref` outside training).
            escores : Index estimating the dissimilarity between `ref` and the reference-period
                scenario, computed after each iteration.

    Notes
    -----
    If `n_escore` is negative, `escores` will be filled with NaNs.
    """
    ref = ds.ref.rename(time_hist="time")
    hist = ds.hist.rename(time_hist="time")
    sim = ds.sim
    dim = kwargs["pts_dim"]
    compute_escore = kwargs["n_escore"] >= 0

    escores = []
    for i, R in enumerate(ds.rot_matrices.transpose("iterations", ...)):
        # @ operator stands for matrix multiplication (along named dimensions): x@R = R@x
        # @R rotates an array defined over dimension x unto new dimension x'. x@R = x'
        refp = ref @ R
        histp = hist @ R
        simp = sim @ R

        # Perform univariate adjustment in rotated space (x')
        ADJ = kwargs["base"].train(refp, histp, **kwargs["base_kws"], skip_input_checks=True)
        scenhp = ADJ.adjust(histp, **kwargs["adj_kws"], skip_input_checks=True)
        scensp = ADJ.adjust(simp, **kwargs["adj_kws"], skip_input_checks=True)

        # Rotate back to original dimension x'@R = x
        # Note that x'@R is a back rotation because the matrix multiplication is now done along x' due to xarray
        # operating along named dimensions.
        # In normal linear algebra, this is equivalent to taking @R.T, the back rotation.
        hist = scenhp @ R
        sim = scensp @ R

        # Compute score between the target and the current reference-period scenario.
        if compute_escore:
            escores.append(
                escore(
                    ref,
                    hist,
                    dims=(dim, "time"),
                    N=kwargs["n_escore"],
                    scale=True,
                ).expand_dims(iterations=[i])
            )

    if compute_escore:
        escores = xr.concat(escores, "iterations")
    else:
        # All nan, but with the proper shape.
        escores = (ref.isel({dim: 0, "time": 0}) * hist.isel({dim: 0, "time": 0})).expand_dims(
            iterations=ds.iterations
        ) * np.nan

    return xr.Dataset(
        data_vars={
            "scenh": hist.rename(time="time_hist").transpose(*ds.hist.dims),
            "scen": sim.transpose(*ds.sim.dims),
            "escores": escores,
        }
    )
def _fit_on_cluster(data, thresh, cluster_thresh, dist):
    """
    Extract clusters on 1D data and fit "dist" on the maximums.

    The fit is done on the excesses over `thresh` with the location fixed at 0. The location
    parameter of the returned list is then set back to `thresh`.
    """
    _, _, _, maximums = u.get_clusters_1d(data, thresh, cluster_thresh)
    params = list(_fitfunc_1d(maximums - thresh, dist=dist, floc=0, nparams=3, method="ML"))
    # We forced 0, put back thresh.
    params[-2] = thresh
    return params


def _extremes_train_1d(ref, hist, ref_params, cluster_thresh, *, q_thresh, dist, N):
    """
    Train for method ExtremeValues, only for 1D input along time.

    Returns the probabilities of the `hist` extremes and the matching adjustment factors (both
    sorted by probability and NaN-padded to length `N`), plus the extremes threshold.
    """
    # Fast-track, do nothing for all-nan slices
    if np.isnan(ref).all() or np.isnan(hist).all():
        return np.full(N, np.nan), np.full(N, np.nan), np.nan

    # Find quantile q_thresh
    thresh = (
        np.nanquantile(ref[ref >= cluster_thresh], q_thresh)
        + np.nanquantile(hist[hist >= cluster_thresh], q_thresh)
    ) / 2

    # Fit genpareto on cluster maximums on ref (if needed) and hist.
    if np.isnan(ref_params).all():
        ref_params = _fit_on_cluster(ref, thresh, cluster_thresh, dist)

    hist_params = _fit_on_cluster(hist, thresh, cluster_thresh, dist)

    # Find probabilities of extremes according to fitted dist
    Px_ref = dist.cdf(ref[ref >= thresh], *ref_params)
    hist = hist[hist >= thresh]
    Px_hist = dist.cdf(hist, *hist_params)

    # Find common probabilities range.
    Pmax = min(Px_ref.max(), Px_hist.max())
    Pmin = max(Px_ref.min(), Px_hist.min())
    Pcommon = (Px_hist <= Pmax) & (Px_hist >= Pmin)
    Px_hist = Px_hist[Pcommon]

    # Find values of hist extremes if they followed ref's distribution.
    hist_in_ref = dist.ppf(Px_hist, *ref_params)

    # Adjustment factors, unsorted
    af = hist_in_ref / hist[Pcommon]
    # sort them in Px order, and pad to have N values.
    order = np.argsort(Px_hist)
    if af.size > N:
        raise ValueError(
            "The number of precipitations part of a cluster is larger than `q_thresh`, which "
            "likely indicates that `cluster_thresh` is too small for `ref` and/or `hist`, i.e. "
            "`cluster_thresh` is still in the bulk of the distribution."
        )
    px_hist = np.pad(Px_hist[order], ((0, N - af.size),), constant_values=np.nan)
    af = np.pad(af[order], ((0, N - af.size),), constant_values=np.nan)

    return px_hist, af, thresh


@map_blocks(reduces=["time"], px_hist=["quantiles"], af=["quantiles"], thresh=[Grouper.PROP])
def extremes_train(
    ds: xr.Dataset,
    *,
    group: Grouper,
    q_thresh: float,
    dist,
    quantiles: np.ndarray,
) -> xr.Dataset:
    """
    Train extremes for a given variable series.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing the reference (`ref`) and historical (`hist`) data, and cluster
        thresholds (`cluster_thresh`). It may also contain `ref_params`. When that variable is
        absent or NaN, the parameters are fitted on the `ref` cluster maximums.
    group : Grouper
        The grouper object.
    q_thresh : float
        The quantile threshold to use.
    dist : Any
        The distribution to fit.
    quantiles : array-like
        The quantiles to compute.

    Returns
    -------
    xr.Dataset
        The dataset containing the quantiles, the adjustment factors, and the threshold.
    """
    # `ref_params` is optional. The previous `ds.ref_params or np.nan` raised an AttributeError
    # when it was missing, and was ambiguous for multi-element arrays.
    ref_params = ds.get("ref_params", np.nan)
    px_hist, af, thresh = xr.apply_ufunc(
        _extremes_train_1d,
        ds.ref,
        ds.hist,
        ref_params,
        ds.cluster_thresh,
        input_core_dims=[("time",), ("time",), (), ()],
        output_core_dims=[("quantiles",), ("quantiles",), ()],
        vectorize=True,
        kwargs={
            "q_thresh": q_thresh,
            "dist": dist,
            "N": len(quantiles),
        },
    )
    # Outputs of map_blocks must have dimensions.
    if not isinstance(thresh, xr.DataArray):
        thresh = xr.DataArray(thresh)
    thresh = thresh.expand_dims(group=[1])
    return xr.Dataset(
        {"px_hist": px_hist, "af": af, "thresh": thresh},
        coords={"quantiles": quantiles},
    )


def _fit_cluster_and_cdf(data, thresh, cluster_thresh, dist):
    """Fit 1D cluster maximums and immediately compute CDF."""
    fut_params = _fit_on_cluster(data, thresh, cluster_thresh, dist)
    return dist.cdf(data, *fut_params)


@map_blocks(reduces=["quantiles", Grouper.PROP], scen=[])
def extremes_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    frac: float,
    power: float,
    dist,
    interp: str,
    extrapolation: str,
) -> xr.Dataset:
    """
    Adjust extremes to reflect many distribution factors.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            sim : Data to adjust.
            scen : Scenario from another adjustment, used below the extremes threshold.
            thresh : Extremes threshold from the training.
            cluster_thresh : Cluster thresholds.
            px_hist : Probabilities of the `hist` extremes from the training.
            af : Adjustment factors matching `px_hist`.
    group : Grouper
        The grouper object.
    frac : float
        The fraction of the transition function.
    power : float
        The power of the transition function.
    dist : Any
        The distribution to fit.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.

    Returns
    -------
    xr.Dataset
        The dataset containing the adjusted data.
    """
    # Find probabilities of extremes of fut according to its own cluster-fitted dist.
    px_fut = xr.apply_ufunc(
        _fit_cluster_and_cdf,
        ds.sim,
        ds.thresh,
        ds.cluster_thresh,
        input_core_dims=[["time"], [], []],
        output_core_dims=[["time"]],
        kwargs={"dist": dist},
        vectorize=True,
    )

    # Find factors by interpolating from hist probs to fut probs. apply them.
    af = u.interp_on_quantiles(px_fut, ds.px_hist, ds.af, method=interp, extrapolation=extrapolation)
    scen = u.apply_correction(ds.sim, af, "*")

    # Smooth transition function between the input scenario and the extremes-adjusted values.
    # Below ds.thresh the transition is 0, so the input `ds.scen` is kept as is.
    transition = (((ds.sim - ds.thresh).clip(0, None) / ((ds.sim.max("time")) - ds.thresh)) / frac) ** power
    transition = transition.clip(0, 1)

    adjusted: xr.DataArray = (transition * scen) + ((1 - transition) * ds.scen)
    out = adjusted.rename("scen").squeeze("group", drop=True).to_dataset()
    return out


def _otc_adjust(
    X: np.ndarray,
    Y: np.ndarray,
    bin_width: dict | float | np.ndarray | None = None,
    bin_origin: dict | float | np.ndarray | None = None,
    num_iter_max: int | None = 100_000_000,
    jitter_inside_bins: bool = True,
    normalization: str | None = "max_distance",
):
    """
    Optimal Transport Correction of the bias of X with respect to Y.

    Parameters
    ----------
    X : np.ndarray
        Historical data to be corrected.
    Y : np.ndarray
        Bias correction reference, target of optimal transport.
    bin_width : dict or float or np.ndarray, optional
        Bin widths for specified dimensions. Unspecified widths are estimated from `X` and `Y`.
    bin_origin : dict or float or np.ndarray, optional
        Bin origins for specified dimensions. Unspecified origins are 0.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover's Distance (EMD) algorithm.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    np.ndarray
        Adjusted data, with NaN rows of `X` kept as NaN.

    References
    ----------
    :cite:cts:`robin_2021`
    """
    # nans are removed and put back in place at the end
    X_og = X.copy()
    mask = (~np.isnan(X)).all(axis=1)
    X = X[mask]
    Y = Y[(~np.isnan(Y)).all(axis=1)]

    # Initialize parameters
    if bin_width is None:
        bin_width = u.bin_width_estimator([Y, X])
    elif isinstance(bin_width, dict):
        _bin_width = u.bin_width_estimator([Y, X])
        for k, v in bin_width.items():
            _bin_width[k] = v
        bin_width = _bin_width
    elif isinstance(bin_width, float | int):
        bin_width = np.ones(X.shape[1]) * bin_width

    if bin_origin is None:
        bin_origin = np.zeros(X.shape[1])
    elif isinstance(bin_origin, dict):
        _bin_origin = np.zeros(X.shape[1])
        for k, v in bin_origin.items():
            _bin_origin[k] = v
        bin_origin = _bin_origin
    elif isinstance(bin_origin, float | int):
        bin_origin = np.ones(X.shape[1]) * bin_origin

    num_iter_max = 100_000_000 if num_iter_max is None else num_iter_max

    # Get the bin positions and frequencies of X and Y, and for all Xs the bin to which they belong
    gridX, muX, binX = u.histogram(X, bin_width, bin_origin)
    gridY, muY, _ = u.histogram(Y, bin_width, bin_origin)

    # Compute the optimal transportation plan
    plan = u.optimal_transport(gridX, gridY, muX, muY, num_iter_max, normalization)

    # Express the target bin positions as bin indices, so that bin centers can be rebuilt below.
    gridY = np.floor((gridY - bin_origin) / bin_width)

    # regroup the indices of all the points belonging to a same bin
    binX_sort = np.lexsort(binX[:, ::-1].T)
    sorted_bins = binX[binX_sort]
    _, binX_start, binX_count = np.unique(sorted_bins, return_index=True, return_counts=True, axis=0)
    binX_start_sort = np.sort(binX_start)
    binX_groups = np.split(binX_sort, binX_start_sort[1:])

    out = np.empty(X.shape)
    rng = np.random.default_rng()
    # The plan row corresponding to a source bin indicates its probabilities to be transported to every target bin
    for i, binX_group in enumerate(binX_groups):
        # Pick as much target bins for this source bin as there are points in the source bin
        choice = rng.choice(range(muY.size), p=plan[i, :], size=binX_count[i])
        out[binX_group] = (gridY[choice] + 1 / 2) * bin_width + bin_origin

    if jitter_inside_bins:
        out += np.random.uniform(low=-bin_width / 2, high=bin_width / 2, size=out.shape)

    # reintroduce nans
    Z = X_og
    Z[mask] = out
    Z[~mask] = np.nan

    return Z


@map_groups(scen=[Grouper.DIM])
def otc_adjust(
    ds: xr.Dataset,
    dim: list,
    pts_dim: str,
    bin_width: dict | float | None = None,
    bin_origin: dict | float | None = None,
    num_iter_max: int | None = 100_000_000,
    jitter_inside_bins: bool = True,
    adapt_freq_thresh: dict | None = None,
    normalization: str | None = "max_distance",
):
    """
    Optimal Transport Correction of the bias of `hist` with respect to `ref`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : list
        The dimensions defining the distribution on which optimal transport is performed.
    pts_dim : str
        The dimension defining the multivariate components of the distribution.
    bin_width : dict or float, optional
        Bin widths for specified dimensions.
    bin_origin : dict or float, optional
        Bin origins for specified dimensions.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover's Distance (EMD) algorithm.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    adapt_freq_thresh : dict, optional
        Threshold for frequency adaptation per variable. Variables mapped to None are skipped.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    xr.Dataset
        Adjusted data.
    """
    ref = ds.ref
    hist = ds.hist

    if adapt_freq_thresh is not None:
        for var, thresh in adapt_freq_thresh.items():
            if thresh is None:
                continue
            ds0 = xr.Dataset({"ref": ref.sel({pts_dim: var}), "sim": hist.sel({pts_dim: var})})
            # `dim` is a required argument of `_preprocess_dataset`; it was previously omitted,
            # which raised a TypeError whenever frequency adaptation was requested.
            hist.loc[{pts_dim: var}] = _preprocess_dataset(ds0, dim=dim, adapt_freq_thresh=thresh).sim

    # Stack the distribution dimensions into one sample dimension per input, dropping NaN samples.
    ref_dim = Grouper.filter_dim(ref, dim)
    ref_map = {d: f"ref_{d}" for d in ref_dim}
    ref = ref.rename(ref_map).stack(dim_ref=list(ref_map.values())).dropna(dim="dim_ref")

    sim_dim = Grouper.filter_dim(hist, dim)
    hist = hist.stack(dim_hist=sim_dim).dropna(dim="dim_hist")

    # The numpy core works on positional indices along `pts_dim`, not variable names.
    if isinstance(bin_width, dict):
        bin_width = {np.where(ref[pts_dim].values == var)[0][0]: op for var, op in bin_width.items()}
    if isinstance(bin_origin, dict):
        bin_origin = {np.where(ref[pts_dim].values == var)[0][0]: op for var, op in bin_origin.items()}

    scen = xr.apply_ufunc(
        _otc_adjust,
        hist,
        ref,
        kwargs={
            "bin_width": bin_width,
            "bin_origin": bin_origin,
            "num_iter_max": num_iter_max,
            "jitter_inside_bins": jitter_inside_bins,
            "normalization": normalization,
        },
        input_core_dims=[["dim_hist", pts_dim], ["dim_ref", pts_dim]],
        output_core_dims=[["dim_hist", pts_dim]],
        keep_attrs=True,
        vectorize=True,
    )

    scen = scen.unstack().rename("scen")

    return scen.to_dataset()


def _dotc_adjust(
    X1: np.ndarray,
    Y0: np.ndarray,
    X0: np.ndarray,
    bin_width: dict | float | None = None,
    bin_origin: dict | float | None = None,
    num_iter_max: int | None = 100_000_000,
    cov_factor: str | None = "std",
    jitter_inside_bins: bool = True,
    kind: dict | None = None,
    normalization: str | None = "max_distance",
):
    """
    Dynamical Optimal Transport Correction of the bias of X with respect to Y.

    Parameters
    ----------
    X1 : np.ndarray
        Simulation data to adjust.
    Y0 : np.ndarray
        Bias correction reference.
    X0 : np.ndarray
        Historical simulation data.
    bin_width : dict or float, optional
        Bin widths for specified dimensions. Unspecified widths are estimated from `Y0`, `X0` and `X1`.
    bin_origin : dict or float, optional
        Bin origins for specified dimensions. Unspecified origins are 0.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover's Distance (EMD) algorithm.
    cov_factor : str, optional
        Rescaling factor ("cholesky", "std", or anything else for no rescaling).
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    kind : dict, optional
        Keys are variable indices and values are adjustment kinds, either additive or multiplicative.
        Unspecified dimensions are treated as "+".
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    np.ndarray
        Adjusted data, with NaN rows of `X1` kept as NaN.

    References
    ----------
    :cite:cts:`robin_2021`
    """
    # nans are removed and put back in place at the end
    X1_og = X1.copy()
    mask = ~np.isnan(X1).any(axis=1)
    X1 = X1[mask]
    X0 = X0[~np.isnan(X0).any(axis=1)]
    Y0 = Y0[~np.isnan(Y0).any(axis=1)]

    # Initialize parameters. The three optimal transport steps below must share the same bins,
    # so defaults are resolved here from all three samples instead of separately in each step.
    if bin_width is None:
        bin_width = u.bin_width_estimator([Y0, X0, X1])
    elif isinstance(bin_width, dict):
        _bin_width = u.bin_width_estimator([Y0, X0, X1])
        for v, k in bin_width.items():
            _bin_width[v] = k
        bin_width = _bin_width
    elif isinstance(bin_width, float | int):
        bin_width = np.ones(X0.shape[1]) * bin_width

    if bin_origin is None:
        bin_origin = np.zeros(X0.shape[1])
    elif isinstance(bin_origin, dict):
        _bin_origin = np.zeros(X0.shape[1])
        for v, k in bin_origin.items():
            _bin_origin[v] = k
        bin_origin = _bin_origin
    elif isinstance(bin_origin, float | int):
        bin_origin = np.ones(X0.shape[1]) * bin_origin

    # Map ref to hist
    yX0 = _otc_adjust(
        Y0,
        X0,
        bin_width=bin_width,
        bin_origin=bin_origin,
        num_iter_max=num_iter_max,
        jitter_inside_bins=False,
        normalization=normalization,
    )

    # Map hist to sim
    yX1 = _otc_adjust(
        yX0,
        X1,
        bin_width=bin_width,
        bin_origin=bin_origin,
        num_iter_max=num_iter_max,
        jitter_inside_bins=False,
        normalization=normalization,
    )

    # Which variables are adjusted multiplicatively ("*"); all others are additive.
    n_vars = yX0.shape[1]
    is_mult = [kind is not None and kind.get(j) == "*" for j in range(n_vars)]

    # Temporal evolution
    motion = np.empty(yX0.shape)
    for j in range(n_vars):
        if is_mult[j]:
            motion[:, j] = yX1[:, j] / yX0[:, j]
        else:
            motion[:, j] = yX1[:, j] - yX0[:, j]

    # Apply a variance dependent rescaling factor
    if cov_factor == "cholesky":
        fact0 = u.eps_cholesky(np.cov(Y0, rowvar=False))
        fact1 = u.eps_cholesky(np.cov(X0, rowvar=False))
        motion = (fact0 @ np.linalg.inv(fact1) @ motion.T).T
    elif cov_factor == "std":
        fact0 = np.std(Y0, axis=0)
        fact1 = np.std(X0, axis=0)
        motion = motion @ np.diag(fact0 / fact1)

    # Apply the evolution to ref
    Y1 = np.empty(yX0.shape)
    for j in range(n_vars):
        if is_mult[j]:
            Y1[:, j] = Y0[:, j] * motion[:, j]
        else:
            Y1[:, j] = Y0[:, j] + motion[:, j]

    # Map sim to the evolution of ref
    out = _otc_adjust(
        X1,
        Y1,
        bin_width=bin_width,
        bin_origin=bin_origin,
        num_iter_max=num_iter_max,
        jitter_inside_bins=jitter_inside_bins,
        normalization=normalization,
    )

    # reintroduce nans
    Z1 = X1_og
    Z1[mask] = out
    Z1[~mask] = np.nan

    return Z1


@map_groups(scen=[Grouper.DIM])
def dotc_adjust(
    ds: xr.Dataset,
    dim: list,
    pts_dim: str,
    bin_width: dict | float | None = None,
    bin_origin: dict | float | None = None,
    num_iter_max: int | None = 100_000_000,
    cov_factor: str | None = "std",
    jitter_inside_bins: bool = True,
    kind: dict | None = None,
    adapt_freq_thresh: dict | None = None,
    normalization: str | None = "max_distance",
):
    """
    Dynamical Optimal Transport Correction of the bias of `sim` with respect to `ref`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
            sim : simulated data
    dim : list
        The dimensions defining the distribution on which optimal transport is performed.
    pts_dim : str
        The dimension defining the multivariate components of the distribution.
    bin_width : dict or float, optional
        Bin widths for specified dimensions.
    bin_origin : dict or float, optional
        Bin origins for specified dimensions.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover's Distance (EMD) algorithm.
    cov_factor : str, optional
        Rescaling factor.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    kind : dict, optional
        Keys are variable names and values are adjustment kinds, either additive or multiplicative.
        Unspecified dimensions are treated as "+".
    adapt_freq_thresh : dict, optional
        Threshold for frequency adaptation per variable. Variables mapped to None are skipped.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    xr.Dataset
        Adjusted data.
    """
    hist = ds.hist
    sim = ds.sim
    ref = ds.ref

    if adapt_freq_thresh is not None:
        for var, thresh in adapt_freq_thresh.items():
            if thresh is not None:
                ds0 = xr.Dataset({"ref": ref.sel({pts_dim: var}), "sim": hist.sel({pts_dim: var})})
                # add the frequency-adaptation variables (e.g. `P0_ref`, `P0_hist`, `pth`)
                ds0 = _preprocess_dataset(ds0, dim=dim, adapt_freq_thresh=thresh)
                hist.loc[{pts_dim: var}] = ds0.sim
                ds0["sim"] = sim.loc[{pts_dim: var}]
                # remove the `ref` variable since we already have `P0_ref` and the other variables
                ds0 = ds0.drop_vars("ref")
                sim.loc[{pts_dim: var}] = _preprocess_dataset(ds0, dim=dim, adapt_freq_thresh=thresh).sim

    # Drop data added by map_blocks and prepare for apply_ufunc
    sim_dim = Grouper.filter_dim(sim, dim)
    hist_map = {d: f"hist_{d}" for d in sim_dim}
    hist = hist.rename(hist_map).stack(dim_hist=list(hist_map.values()))

    ref_dim = Grouper.filter_dim(ref, dim)
    ref_map = {d: f"ref_{d}" for d in ref_dim}
    ref = ref.rename(ref_map).stack(dim_ref=list(ref_map.values()))

    sim = sim.stack(dim_sim=sim_dim)

    # The numpy core works on positional indices along `pts_dim`, not variable names.
    if kind is not None:
        kind = {np.where(ref[pts_dim].values == var)[0][0]: op for var, op in kind.items()}
    if isinstance(bin_width, dict):
        bin_width = {np.where(ref[pts_dim].values == var)[0][0]: op for var, op in bin_width.items()}
    if isinstance(bin_origin, dict):
        bin_origin = {np.where(ref[pts_dim].values == var)[0][0]: op for var, op in bin_origin.items()}

    scen = xr.apply_ufunc(
        _dotc_adjust,
        sim,
        ref,
        hist,
        kwargs={
            "bin_width": bin_width,
            "bin_origin": bin_origin,
            "num_iter_max": num_iter_max,
            "cov_factor": cov_factor,
            "jitter_inside_bins": jitter_inside_bins,
            "kind": kind,
            "normalization": normalization,
        },
        input_core_dims=[
            ["dim_sim", pts_dim],
            ["dim_ref", pts_dim],
            ["dim_hist", pts_dim],
        ],
        output_core_dims=[["dim_sim", pts_dim]],
        keep_attrs=True,
        vectorize=True,
    )

    scen = scen.unstack().rename("scen")

    return scen.to_dataset()