"""
Adjustment Algorithms
=====================
This file defines the different steps, to be wrapped into the Adjustment objects.
"""
from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
import numpy as np
import xarray as xr
from . import nbutils as nbu
from . import utils as u
from ._processing import _adapt_freq
from .base import Grouper, map_blocks, map_groups
from .detrending import PolyDetrend
from .options import set_options
from .processing import (
escore,
jitter_over_thresh,
jitter_under_thresh,
reordering,
standardize,
)
from .units import convert_units_to
from .utils import _fitfunc_1d
def _adapt_freq_preprocess(ds, adapt_freq_thresh, group: Grouper | None, dim: str | None):
if adapt_freq_thresh is None:
return ds
if (group is None) ^ (dim is None) is False:
raise ValueError("Either `group` or `dim` must be None.")
thresh = convert_units_to(adapt_freq_thresh, ds.sim)
if group:
out = _adapt_freq(ds, group=group, thresh=thresh).rename({"sim_ad": "sim"})
else:
out = _adapt_freq.func(ds, dim=dim, thresh=thresh).rename({"sim_ad": "sim"})
ds = ds.assign({v: out[v] for v in out.data_vars})
# `P0_ref` and `P0_hist` give enough information
ds = ds.drop_vars("dP0")
return ds
def _preprocess_dataset(
ds: xr.Dataset,
dim: str | list,
adapt_freq_thresh: str | None = None,
jitter_under_thresh_value: str | None = None,
jitter_over_thresh_value: str | None = None,
jitter_over_thresh_upper_bnd: str | None = None,
):
dim = dim if isinstance(dim, list) else [dim]
# uniformize the notation, change back at the end
if rename_hist := ("hist" in ds):
ds = ds.rename({"hist": "sim"})
if jitter_under_thresh_value:
ds["sim"] = jitter_under_thresh(ds.sim, jitter_under_thresh_value)
if (jitter_over_thresh_value is None) ^ (jitter_over_thresh_upper_bnd is None):
raise ValueError("`jitter_over_thresh_value` and `jitter_over_thresh_upper_bnd` must both be specified or both be `None` (default)")
if jitter_over_thresh_value:
ds["sim"] = jitter_over_thresh(ds.sim, jitter_over_thresh_value, jitter_over_thresh_upper_bnd)
if adapt_freq_thresh:
ds = _adapt_freq_preprocess(ds, adapt_freq_thresh, None, dim)
else:
# pick the dataset with the largest number of dimensions
# keep add_dims on the new datasets
dim = Grouper.filter_add_dims(dim)
ds0 = ds.sim if len(ds.sim.dims) >= len(ds.ref.dims) else ds.ref
dummy = xr.full_like(ds0[{d: 0 for d in dim}], np.nan)
ds = ds.assign(P0_ref=dummy, P0_hist=dummy, pth=dummy)
if rename_hist:
ds = ds.rename({"sim": "hist"})
return ds
@map_groups(
    af=[Grouper.PROP, "quantiles"],
    hist_q=[Grouper.PROP, "quantiles"],
    hist_q_raw=[Grouper.PROP, "quantiles"],
    scaling=[Grouper.PROP],
    P0_ref=[Grouper.PROP, Grouper.ADD_DIMS],
    P0_hist=[Grouper.PROP, Grouper.ADD_DIMS],
    pth=[Grouper.PROP, Grouper.ADD_DIMS],
)
def dqm_train(
    ds: xr.Dataset,
    *,
    dim: str,
    kind: str,
    quantiles: np.ndarray,
    adapt_freq_thresh: str | None = None,
    jitter_under_thresh_value: str | None = None,
    jitter_over_thresh_value: str | None = None,
    jitter_over_thresh_upper_bnd: str | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    DQM: Train step on one group.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : str
        The dimension along which to compute the quantiles.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    quantiles : array-like
        The quantiles to compute.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    jitter_under_thresh_value : str, optional
        Threshold under which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter under thresh is not performed.
    jitter_over_thresh_value : str, optional
        Threshold above which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    jitter_over_thresh_upper_bnd : str, optional
        Maximum possible value for the random noise, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    max_tail_factor : float, optional
        If not None, the quantiles of `hist` taken before any preprocessing are returned as
        `hist_q_raw`, so the adjust step can leave unadjusted the values above
        ``max_tail_factor`` times the last of those quantiles.
        If None, `hist_q_raw` is an all-NaN dummy variable.

    Returns
    -------
    xr.Dataset
        The adjustment factors (`af`), the quantiles over the normalized training data (`hist_q`),
        `hist_q_raw`, the scaling factor (`scaling`), and `P0_ref`, `P0_hist` and `pth`
        as returned by the preprocessing.

    Notes
    -----
    `jitter_over_thresh_value` and `jitter_over_thresh_upper_bnd` must both be specified to
    use `jitter_over_thresh`, or both be None (default) to skip it.
    """
    keep_raw_quantiles = max_tail_factor is not None

    raw_dim = Grouper.filter_dim(ds.hist, dim)
    if keep_raw_quantiles:
        # Quantiles of the untouched `hist`, needed by the adjust step for `max_tail_factor`.
        hist_q_raw = nbu.quantile(ds.hist, quantiles, raw_dim)

    ds = _preprocess_dataset(
        ds,
        dim,
        adapt_freq_thresh=adapt_freq_thresh,
        jitter_under_thresh_value=jitter_under_thresh_value,
        jitter_over_thresh_value=jitter_over_thresh_value,
        jitter_over_thresh_upper_bnd=jitter_over_thresh_upper_bnd,
    )

    # Extra dimensions are only aggregated in the variables that have them.
    ref_dim = Grouper.filter_dim(ds.ref, dim)
    # `hist` may have been broadcast during preprocessing: recompute its dims.
    sim_dim = Grouper.filter_dim(ds.hist, dim)

    # Remove the mean from each series before taking the quantiles.
    ref_offset = u.invert(ds.ref.mean(ref_dim), kind)
    refn = u.apply_correction(ds.ref, ref_offset, kind)
    hist_offset = u.invert(ds.hist.mean(sim_dim), kind)
    histn = u.apply_correction(ds.hist, hist_offset, kind)

    ref_q = nbu.quantile(refn, quantiles, ref_dim)
    hist_q = nbu.quantile(histn, quantiles, sim_dim)
    if not keep_raw_quantiles:
        # Dummy variable so the output structure does not depend on the options.
        hist_q_raw = xr.full_like(hist_q, np.nan)

    af = u.get_correction(hist_q, ref_q, kind)

    mu_ref = ds.ref.mean(ref_dim)
    mu_hist = ds.hist.mean(sim_dim)
    scaling = u.get_correction(mu_hist, mu_ref, kind=kind)

    trained = {
        "af": af,
        "hist_q": hist_q,
        "hist_q_raw": hist_q_raw,
        "scaling": scaling,
        "P0_ref": ds.P0_ref,
        "P0_hist": ds.P0_hist,
        "pth": ds.pth,
    }
    return xr.Dataset(data_vars=trained)
@map_groups(
    af=[Grouper.PROP, "quantiles"],
    hist_q=[Grouper.PROP, "quantiles"],
    hist_q_raw=[Grouper.PROP, "quantiles"],
    P0_ref=[Grouper.PROP, Grouper.ADD_DIMS],
    P0_hist=[Grouper.PROP, Grouper.ADD_DIMS],
    pth=[Grouper.PROP, Grouper.ADD_DIMS],
)
def eqm_train(
    ds: xr.Dataset,
    *,
    dim: str,
    kind: str,
    quantiles: np.ndarray,
    adapt_freq_thresh: str | None = None,
    jitter_under_thresh_value: str | None = None,
    jitter_over_thresh_value: str | None = None,
    jitter_over_thresh_upper_bnd: str | None = None,
    max_tail_factor: float | None = None,
) -> xr.Dataset:
    """
    EQM: Train step on one group.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : str
        The dimension along which to compute the quantiles.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    quantiles : array-like
        The quantiles to compute.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    jitter_under_thresh_value : str, optional
        Threshold under which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter under thresh is not performed.
    jitter_over_thresh_value : str, optional
        Threshold above which a uniform random noise is added to values, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    jitter_over_thresh_upper_bnd : str, optional
        Maximum possible value for the random noise, a quantity with units.
        Default is None, meaning that jitter over thresh is not performed.
    max_tail_factor : float, optional
        If not None, the quantiles of `hist` taken before any preprocessing are returned as
        `hist_q_raw`, so the adjust step can leave unadjusted the values above
        ``max_tail_factor`` times the last of those quantiles.
        If None, `hist_q_raw` is an all-NaN dummy variable.

    Returns
    -------
    xr.Dataset
        The adjustment factors (`af`), the quantiles over the training data (`hist_q`),
        `P0_ref`, `P0_hist` and `pth` as returned by the preprocessing, and `hist_q_raw`.

    Notes
    -----
    `jitter_over_thresh_value` and `jitter_over_thresh_upper_bnd` must both be specified to
    use `jitter_over_thresh`, or both be None (default) to skip it.
    """
    keep_raw_quantiles = max_tail_factor is not None

    raw_dim = Grouper.filter_dim(ds.hist, dim)
    if keep_raw_quantiles:
        # Quantiles of the untouched `hist`, needed by the adjust step for `max_tail_factor`.
        hist_q_raw = nbu.quantile(ds.hist, quantiles, raw_dim)

    ds = _preprocess_dataset(
        ds,
        dim,
        adapt_freq_thresh=adapt_freq_thresh,
        jitter_under_thresh_value=jitter_under_thresh_value,
        jitter_over_thresh_value=jitter_over_thresh_value,
        jitter_over_thresh_upper_bnd=jitter_over_thresh_upper_bnd,
    )

    # Extra dimensions are only aggregated in the variables that have them.
    ref_dim = Grouper.filter_dim(ds.ref, dim)
    # `hist` may have been broadcast during preprocessing: recompute its dims.
    sim_dim = Grouper.filter_dim(ds.hist, dim)

    ref_q = nbu.quantile(ds.ref, quantiles, ref_dim)
    hist_q = nbu.quantile(ds.hist, quantiles, sim_dim)
    if not keep_raw_quantiles:
        # Dummy variable so the output structure does not depend on the options.
        hist_q_raw = xr.full_like(hist_q, np.nan)

    trained = {
        "af": u.get_correction(hist_q, ref_q, kind),
        "hist_q": hist_q,
        "P0_ref": ds.P0_ref,
        "P0_hist": ds.P0_hist,
        "pth": ds.pth,
        "hist_q_raw": hist_q_raw,
    }
    return xr.Dataset(data_vars=trained)
def _npdft_train(ref, hist, rots, quantiles, method, extrap, n_escore, standardize):
r"""
Npdf transform to correct a source `hist` into target `ref`.
Perform a rotation, bias correct `hist` into `ref` with QuantileDeltaMapping, and rotate back.
Do this iteratively over all rotations `rots` and conserve adjustment factors `af_q` in each iteration.
Notes
-----
This function expects numpy inputs. The input arrays `ref,hist` are expected to be 2-dimensional arrays with shape:
`(len(nfeature), len(time))`, where `nfeature` is the dimension which is mixed by the multivariate bias adjustment
(e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape
`(len(iterations), len(nfeature), len(nfeature))`.
"""
if standardize:
ref = (ref - np.nanmean(ref, axis=-1, keepdims=True)) / (np.nanstd(ref, axis=-1, keepdims=True))
hist = (hist - np.nanmean(hist, axis=-1, keepdims=True)) / (np.nanstd(hist, axis=-1, keepdims=True))
af_q = np.zeros((len(rots), ref.shape[0], len(quantiles)))
escores = np.zeros(len(rots)) * np.nan
if n_escore > 0:
ref_step, hist_step = (int(np.ceil(arr.shape[1] / n_escore)) for arr in [ref, hist])
for ii, _rot in enumerate(rots):
rot = _rot if ii == 0 else _rot @ rots[ii - 1].T
ref, hist = rot @ ref, rot @ hist
# loop over variables
for iv in range(ref.shape[0]):
ref_q, hist_q = nbu._quantile(ref[iv], quantiles), nbu._quantile(hist[iv], quantiles)
af_q[ii, iv] = ref_q - hist_q
af = u._interp_on_quantiles_1D(
u._rank_bn(hist[iv]),
quantiles,
af_q[ii, iv],
method=method,
extrap=extrap,
)
hist[iv] = hist[iv] + af
if n_escore > 0:
escores[ii] = nbu._escore(ref[:, ::ref_step], hist[:, ::hist_step])
hist = rots[-1].T @ hist # FIXME: This variable is unused.
return af_q, escores
def mbcn_train(
    ds: xr.Dataset,
    rot_matrices: xr.DataArray,
    pts_dims: Sequence[str],
    quantiles: np.ndarray,
    gw_idxs: xr.DataArray,
    interp: str,
    extrapolation: str,
    n_escore: int,
) -> xr.Dataset:
    """
    Npdf transform training.

    Adjusting factors obtained for each rotation in the npdf transform and conserved to be applied in
    the adjusting step in :py:func:`mbcn_adjust`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    rot_matrices : xr.DataArray
        The rotation matrices as a 3D array ('iterations', <pts_dims[0]>, <pts_dims[1]>), with shape (n_iter, <N>, <N>).
    pts_dims : sequence of str
        The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which
        is the normal case when using :py:func:`xsdba.stack_variables`, and "multivar_prime".
    quantiles : array-like
        The quantiles to compute.
    gw_idxs : xr.DataArray
        Indices of the times in each windowed time group. Its ``group_dim`` attribute names
        the dimension that enumerates the time blocks.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.
    n_escore : int
        Number of elements to include in the e_score test. It is passed unchanged to
        :py:func:`_npdft_train`, which only computes the scores when it is strictly positive.
        NOTE(review): the original text documented "0 for all, < 0 to skip"; presumably the
        caller converts 0 into the sample size before calling — confirm.

    Returns
    -------
    xr.Dataset
        The dataset containing the adjustment factors (`af_q`) and the energy scores (`escores`)
        of each time block (only the npdf transform of mbcn).
    """
    # unpack data
    ref = ds.ref
    hist = ds.hist
    gr_dim = gw_idxs.attrs["group_dim"]

    # Options of `_npdft_train`, identical for every time block.
    npdft_kwargs = {
        "method": interp,
        "extrap": extrapolation,
        "n_escore": n_escore,
        "standardize": True,
    }

    # npdf training core
    af_q_l = []
    escores_l = []

    # loop over time blocks
    for ib in range(gw_idxs[gr_dim].size):
        # indices in a given time block; NaN entries are mapped to -1 and dropped
        indices = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values
        ind = indices[indices >= 0]

        # npdft training : multiple rotations on standardized datasets
        # keep track of adjustment factors in each rotation for later use
        af_q, escores = xr.apply_ufunc(
            _npdft_train,
            ref[{"time": ind}],
            hist[{"time": ind}],
            rot_matrices,
            quantiles,
            input_core_dims=[
                [pts_dims[0], "time"],
                [pts_dims[0], "time"],
                ["iterations", pts_dims[1], pts_dims[0]],
                ["quantiles"],
            ],
            output_core_dims=[
                ["iterations", pts_dims[1], "quantiles"],
                ["iterations"],
            ],
            dask="parallelized",
            output_dtypes=[hist.dtype, hist.dtype],
            kwargs=npdft_kwargs,
            vectorize=True,
        )
        af_q_l.append(af_q.expand_dims({gr_dim: [ib]}))
        escores_l.append(escores.expand_dims({gr_dim: [ib]}))

    af_q = xr.concat(af_q_l, dim=gr_dim)
    escores = xr.concat(escores_l, dim=gr_dim)
    return xr.Dataset({"af_q": af_q, "escores": escores}).assign_coords(
        {"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values}
    )
def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap):
"""
Npdf transform adjusting.
Adjusting factors `af_q` obtained in the training step are applied on the simulated data `sim` at each iterated
rotation, see :py:func:`_npdft_train`.
This function expects numpy inputs. `sim` can be a 2-d array with shape: `(len(nfeature), len(time))`, or
a 3-d array with shape: `(len(period), len(nfeature), len(time))`, allowing to adjust multiple climatological periods
all at once. `nfeature` is the dimension which is mixed by the multivariate bias adjustment
(e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape
`(len(iterations), len(nfeature), len(nfeature))`.
"""
# add dummy dim if period_dim absent to uniformize the function below
# This could be done at higher level, not sure where is best
if dummy_dim_added := (len(sim.shape) == 2):
sim = sim[:, np.newaxis, :]
# adjust npdft
for ii, _rot in enumerate(rots):
rot = _rot if ii == 0 else _rot @ rots[ii - 1].T
sim = np.einsum("ij,j...->i...", rot, sim)
# loop over variables
for iv in range(sim.shape[0]):
af = u._interp_on_quantiles_1D_multi(
u._rank_bn(sim[iv], axis=-1),
quantiles,
af_q[ii, iv],
method=method,
extrap=extrap,
)
sim[iv] = sim[iv] + af
rot = rots[-1].T
sim = np.einsum("ij,j...->i...", rot, sim)
if dummy_dim_added:
sim = sim[:, 0, :]
return sim
def mbcn_adjust(
    ref: xr.DataArray,
    hist: xr.DataArray,
    sim: xr.DataArray,
    ds: xr.Dataset,
    g_idxs: xr.DataArray,
    gw_idxs: xr.DataArray,
    pts_dims: tuple[str, str],
    interp: str,
    extrapolation: str,
    base: Callable,
    base_kws_vars: dict,
    adj_kws: dict,
    period_dim: str | None,
) -> xr.Dataset:
    """
    Perform the adjustment portion MBCn multivariate bias correction technique.

    The function :py:func:`mbcn_train` pre-computes the adjustment factors for each rotation
    in the npdf portion of the MBCn algorithm. The rest of adjustment is performed here
    in `mbcn_adjust`.

    Parameters
    ----------
    ref : xr.DataArray
        Training target.
    hist : xr.DataArray
        Training data.
    sim : xr.DataArray
        Data to adjust (stacked with multivariate dimension).
    ds : xr.Dataset
        Dataset variables:
            rot_matrices : Rotation matrices used in the training step.
            af_q : Adjustment factors obtained in the training step for the npdf transform
    g_idxs : xr.DataArray
        Indices of the times in each time group.
    gw_idxs : xr.DataArray
        Indices of the times in each windowed time group. Its attributes ``group_dim`` and
        ``group`` give the name of the block dimension and the window (``group[1]``).
    pts_dims : [str, str]
        The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which
        is the normal case when using :py:func:`xsdba.stack_variables`, and "multivar_prime".
    interp : str
        Interpolation method for the npdf transform (same as in the training step).
    extrapolation : str
        Extrapolation method for the npdf transform (same as in the training step).
    base : BaseAdjustment
        Bias-adjustment class used for the univariate bias correction.
    base_kws_vars : Dict
        Options for univariate training for the scenario that is reordered with the output of npdf transform.
        The arguments are those expected by TrainAdjust classes along with
        - kinds : Dict of correction kinds for each variable (e.g. {"pr":"*", "tasmax":"+"}).
    adj_kws : Dict
        Options for univariate adjust for the scenario that is reordered with the output of npdf transform.
    period_dim : str, optional
        Name of the period dimension used when stacking time periods of `sim` using :py:func:`xsdba.stack_periods`.
        If specified, the interpolation of the npdf transform is performed only once and applied on all periods simultaneously.
        This should be more performant, but also more memory intensive. Defaults to `None`: No optimization will be attempted.

    Returns
    -------
    xr.Dataset
        The adjusted data, in the variable `scen`.
    """
    # unpacking training parameters
    rot_matrices = ds.rot_matrices
    af_q = ds.af_q
    quantiles = af_q.quantiles
    gr_dim = gw_idxs.attrs["group_dim"]
    win = gw_idxs.attrs["group"][1]

    # this way of handling was letting open the possibility to perform
    # interpolation for multiple periods in the simulation all at once
    # in principle, avoiding redundancy. Need to test this on small data
    # to confirm it works, and on big data to check performance.
    dims = ["time"] if period_dim is None else [period_dim, "time"]
    sim_core_dims = [pts_dims[0]] + dims

    # mbcn core
    scen_mbcn = xr.zeros_like(sim)
    for ib in range(gw_idxs[gr_dim].size):
        # indices in a given time block (with and without the window);
        # NaN entries are mapped to -1 and dropped
        indices_gw = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values
        ind_gw = indices_gw[indices_gw >= 0]
        indices_g = g_idxs[{gr_dim: ib}].fillna(-1).astype(int).values
        ind_g = indices_g[indices_g >= 0]

        sim_block = sim[{"time": ind_gw}]

        # 1. univariate adjustment of sim -> scen
        # the kind may differ depending on the variables
        scen_block = xr.zeros_like(sim_block)
        for iv, v in enumerate(sim[pts_dims[0]].values):
            sl = {"time": ind_gw, pts_dims[0]: iv}
            with set_options(extra_output=False):
                adjustment = base.train(ref[sl], hist[sl], **base_kws_vars[v], skip_input_checks=True)
                scen_block[{pts_dims[0]: iv}] = adjustment.adjust(sim[sl], **adj_kws, skip_input_checks=True)

        # 2. npdft adjustment of sim (on a standardized copy)
        npdft_block = xr.apply_ufunc(
            _npdft_adjust,
            standardize(sim_block.copy(), dim="time")[0],
            af_q[{gr_dim: ib}],
            rot_matrices,
            quantiles,
            input_core_dims=[
                sim_core_dims,
                ["iterations", pts_dims[1], "quantiles"],
                ["iterations", pts_dims[1], pts_dims[0]],
                ["quantiles"],
            ],
            output_core_dims=[
                sim_core_dims,
            ],
            dask="parallelized",
            output_dtypes=[sim.dtype],
            kwargs={"method": interp, "extrap": extrapolation},
            vectorize=True,
        )

        # 3. reorder scen according to npdft results
        reordered = reordering(ref=npdft_block, sim=scen_block)
        if win > 1:
            # keep central value of window (intersecting indices in gw_idxs and g_idxs)
            scen_mbcn[{"time": ind_g}] = reordered[{"time": np.isin(ind_gw, ind_g)}]
        else:
            scen_mbcn[{"time": ind_g}] = reordered

    return scen_mbcn.to_dataset(name="scen")
@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[])
def qm_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    interp: str,
    extrapolation: str,
    kind: str,
    adapt_freq_thresh: str | None = None,
    max_tail_factor: int | None = None,
) -> xr.Dataset:
    """
    QM (DQM and EQM): Adjust step on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors
            hist_q : Quantiles over the training data
            sim : Data to adjust.
            P0_ref (optional) : Proportion of zeroes in the reference
            P0_hist (optional) : Proportion of zeroes in the historical period of the simulation
            pth (optional) : The smallest value of `hist` that was not frequency-adjusted in the training.
            hist_q_raw (optional) : Quantiles of `hist` before preprocessing, read only when
            `max_tail_factor` is given.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.

    Returns
    -------
    xr.Dataset
        The adjusted data, in the variable `scen`.
    """
    if adapt_freq_thresh:
        freq_inputs = ds[["sim", "P0_ref", "P0_hist", "pth"]]
        ds["sim"] = _adapt_freq_preprocess(
            freq_inputs,
            adapt_freq_thresh,
            group=Grouper(group.name),
            dim=None,
        ).sim

    # No bias adjustment where sim is larger than n times the largest quantile
    # of hist (taken without frequency adaptation).
    use_tail_mask = max_tail_factor is not None
    if use_tail_mask:
        adaptedsim = ds["sim"].copy()
        tail_interp = "nearest" if group.prop == "dayofyear" else interp
        # broadcast the last quantile so that its dims fit those of the simulation
        last_quantile = u.broadcast(
            ds["hist_q_raw"].isel({"quantiles": -1}).drop_vars("quantiles"),
            adaptedsim,
            group=group,
            interp=tail_interp,
        )
        mask = adaptedsim > max_tail_factor * last_quantile

    af = u.interp_on_quantiles(
        ds.sim,
        ds.hist_q,
        ds.af,
        group=group,
        method=interp,
        extrapolation=extrapolation,
    )
    scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen")

    if use_tail_mask:
        # masked values fall back to the preprocessed simulation
        scen = scen.where(~mask, adaptedsim)

    return scen.to_dataset()
@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], trend=[])
def dqm_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    interp: str,
    kind: str,
    extrapolation: str,
    detrend: int | PolyDetrend,
    adapt_freq_thresh: str | None = None,
    max_tail_factor: int | None = None,
) -> xr.Dataset:
    """
    DQM adjustment on one block.

    The simulation is scaled, detrended, adjusted with :py:func:`qm_adjust` and retrended.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            scaling : Scaling factor between ref and hist
            af : Adjustment factors
            hist_q : Quantiles over the training data
            sim : Data to adjust
            P0_ref (optional) : Proportion of zeroes in the reference
            P0_hist (optional) : Proportion of zeroes in the historical period of the simulation
            pth (optional) : The smallest value of `hist` that was not frequency-adjusted in the training.
            hist_q_raw (optional) : Quantiles of `hist` before preprocessing, read only when
            `max_tail_factor` is given.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method to use.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    extrapolation : str
        The extrapolation method to use.
    detrend : int | PolyDetrend
        Either the degree of the polynomial detrending to apply (a `PolyDetrend` is then built),
        or a detrending object used as given.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.

    Returns
    -------
    xr.Dataset
        The adjusted data (`scen`) and the trend (`trend`).
    """
    if adapt_freq_thresh:
        freq_inputs = ds[["sim", "P0_ref", "P0_hist", "pth"]]
        ds["sim"] = _adapt_freq_preprocess(
            freq_inputs,
            adapt_freq_thresh,
            group=Grouper(group.name),
            dim=None,
        ).sim

    # No bias adjustment where sim is larger than n times the largest quantile
    # of hist (taken without frequency adaptation).
    use_tail_mask = max_tail_factor is not None
    if use_tail_mask:
        adaptedsim = ds["sim"].copy()
        tail_interp = "nearest" if group.prop == "dayofyear" else interp
        # broadcast the last quantile so that its dims fit those of the simulation
        last_quantile = u.broadcast(
            ds["hist_q_raw"].isel({"quantiles": -1}).drop_vars("quantiles"),
            adaptedsim,
            group=group,
            interp=tail_interp,
        )
        mask = adaptedsim > max_tail_factor * last_quantile

    scaling_interp = "nearest" if group.prop == "dayofyear" else interp
    scaling = u.broadcast(ds.scaling, ds.sim, group=group, interp=scaling_interp)
    scaled_sim = u.apply_correction(ds.sim, scaling, kind).assign_attrs({"units": ds.sim.units})

    # An integer is a polynomial degree, anything else is used as the detrending object.
    detrending = PolyDetrend(degree=detrend, kind=kind, group=group) if isinstance(detrend, int) else detrend
    detrending = detrending.fit(scaled_sim)
    ds["sim"] = detrending.detrend(scaled_sim)

    # Preprocessing and tail masking are handled here, so they are not requested again.
    scen = qm_adjust.func(
        ds,
        group=group,
        interp=interp,
        extrapolation=extrapolation,
        kind=kind,
    ).scen
    scen = detrending.retrend(scen)

    if use_tail_mask:
        # masked values fall back to the preprocessed simulation
        scen = scen.where(~mask, adaptedsim)

    return xr.Dataset({"scen": scen, "trend": detrending.ds.trend})
@map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], sim_q=[])
def qdm_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    interp: str,
    extrapolation: str,
    kind: str,
    adapt_freq_thresh: str | None = None,
    rank_window: bool | None = None,
    max_tail_factor: int | None = None,
) -> xr.Dataset:
    """
    QDM adjustment on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors
            quantiles : Quantiles associated with the adjustment factors
            sim : Data to adjust.
            P0_ref, P0_hist, pth (optional) : read only when `adapt_freq_thresh` is given.
            hist_q_raw (optional) : Quantiles of `hist` before preprocessing, read only when
            `max_tail_factor` is given.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method to use.
    extrapolation : str
        The extrapolation method to use.
    kind : str
        The kind of correction to compute. See :py:func:`xsdba.utils.get_correction`.
    adapt_freq_thresh : str, optional
        Threshold for frequency adaptation. See :py:class:`xsdba.processing.adapt_freq` for details.
        Default is None, meaning that frequency adaptation is not performed.
    rank_window : bool, optional
        Whether to rank simulated values over the full grouping window.
        Effectively, the default is `False` which preserves legacy behavior.
        If `False`, ranks are computed within exact groups, e.g., a specific day of year.
        In `xsdba>=0.8`, this option will be deprecated in favour of honoring the grouping window (equivalent to `True`).
    max_tail_factor : float, optional
        If not None, values to adjust (after preprocessing steps) that are above max_tail_factor * the value
        of the last quantile of hist (before the preprocessing steps) are not adjusted.
        We keep the input simulation with only the preprocessing steps instead.

    Returns
    -------
    xr.Dataset
        The adjusted data (`scen`) and the ranks of the simulation (`sim_q`).

    Warns
    -----
    DeprecationWarning
        If rank_window is None and the group window is larger than one, a warning indicates that in `xsdba>=0.8`
        the current behavior will change.
    """
    if adapt_freq_thresh:
        freq_inputs = ds[["sim", "P0_ref", "P0_hist", "pth"]]
        ds["sim"] = _adapt_freq_preprocess(
            freq_inputs,
            adapt_freq_thresh,
            group=Grouper(group.name),
            dim=None,
        ).sim

    # No bias adjustment where sim is larger than n times the largest quantile
    # of hist (taken without frequency adaptation).
    use_tail_mask = max_tail_factor is not None
    if use_tail_mask:
        adaptedsim = ds["sim"].copy()
        tail_interp = "nearest" if group.prop == "dayofyear" else interp
        # broadcast the last quantile so that its dims fit those of the simulation
        last_quantile = u.broadcast(
            ds["hist_q_raw"].isel({"quantiles": -1}).drop_vars("quantiles"),
            adaptedsim,
            group=group,
            interp=tail_interp,
        )
        mask = adaptedsim > max_tail_factor * last_quantile

    if rank_window is None:
        # Legacy default; only worth a warning when a window is actually in use.
        rank_window = False
        if group.window > 1:
            warnings.warn(
                "QDM method can now perform the adjustment step by expanding the time dimension "
                "with the same window as used in the training. This can already be used by setting "
                "`rank_window = True`. This will be the only possible behaviour in `xsdba>=0.8`. "
                "The current behaviour is obtained by setting `rank_window = False` and will be "
                "deprecated in `xsdba>=0.8`. It will still be possible to use the old behaviour by "
                "monkeypatching the group argument in the QDM class between the training and "
                "adjustment, though this behaviour is not recommended.",
                category=DeprecationWarning,
                stacklevel=2,
            )

    sim_q = group.apply(u.rank, ds.sim, main_only=not rank_window, pct=True)
    af = u.interp_on_quantiles(
        sim_q,
        ds.quantiles,
        ds.af,
        group=group,
        method=interp,
        extrapolation=extrapolation,
    )
    scen = u.apply_correction(ds.sim, af, kind)

    if use_tail_mask:
        # masked values fall back to the preprocessed simulation
        scen = scen.where(~mask, adaptedsim)

    return xr.Dataset({"scen": scen, "sim_q": sim_q})
@map_blocks(
    reduces=[Grouper.ADD_DIMS, Grouper.DIM],
    af=[Grouper.PROP],
    hist_thresh=[Grouper.PROP],
)
def loci_train(ds: xr.Dataset, *, group, thresh) -> xr.Dataset:
    """
    LOCI: Train on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    group : Grouper
        The grouper object.
    thresh : Any
        Threshold on `ref`; it is compared directly with `ds.ref`.

    Returns
    -------
    xr.Dataset
        `hist_thresh`, hist's equivalent of `thresh` obtained with `u.map_cdf`, and `af`, the
        multiplicative factor between the mean excesses over threshold of `hist` and `ref`.
    """
    cdf_inputs = ds.rename(hist="x", ref="y")
    hist_thresh = group.apply(u.map_cdf, cdf_inputs, y_value=thresh).isel(x=0)
    hist_thresh_bc = u.broadcast(hist_thresh, ds.hist, group=group)

    # Keep only the values at or above the respective thresholds.
    hist_over = xr.where(ds.hist >= hist_thresh_bc, ds.hist, np.nan)
    ref_over = xr.where(ds.ref >= thresh, ds.ref, np.nan)
    hist_over_mean = group.apply("mean", hist_over, skipna=True)
    ref_over_mean = group.apply("mean", ref_over, skipna=True)

    # Adjustment factor
    af = u.get_correction(hist_over_mean - hist_thresh, ref_over_mean - thresh, u.MULTIPLICATIVE)
    return xr.Dataset({"af": af, "hist_thresh": hist_thresh})
@map_blocks(reduces=[Grouper.PROP], scen=[])
def loci_adjust(ds: xr.Dataset, *, group, thresh, interp) -> xr.Dataset:
    """
    LOCI: Adjust on one block.

    Computes ``af * (sim - hist_thresh) + thresh`` and clips the result at 0.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            hist_thresh : Hist's equivalent thresh from ref
            af : Adjustment factors
            sim : Data to adjust
    group : Grouper
        The grouper object.
    thresh : Any
        The threshold added back after scaling.
    interp : str
        The interpolation method used to broadcast `hist_thresh` and `af` onto `sim`.

    Returns
    -------
    xr.Dataset
        The adjusted data, in the variable `scen`.
    """
    hist_thresh = u.broadcast(ds.hist_thresh, ds.sim, group=group, interp=interp)
    factor = u.broadcast(ds.af, ds.sim, group=group, interp=interp)

    with xr.set_options(keep_attrs=True):
        excess = ds.sim - hist_thresh
        adjusted = factor * excess + thresh
        scen: xr.DataArray = adjusted.clip(min=0).rename("scen")

    return scen.to_dataset()
@map_groups(af=[Grouper.PROP])
def scaling_train(ds: xr.Dataset, *, dim, kind) -> xr.Dataset:
    """
    Scaling: Train on one group.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : str or sequence of str
        The dimension(s) of the group; they are filtered for each variable with `Grouper.filter_dim`.
    kind : str
        The kind of correction to compute, passed to `u.get_correction`.

    Returns
    -------
    xr.Dataset
        The correction between the mean of `hist` and the mean of `ref`, in the variable `af`.
    """
    # Extra dimensions are only aggregated in the variables that have them.
    ref_dim = Grouper.filter_dim(ds.ref, dim)
    sim_dim = Grouper.filter_dim(ds.hist, dim)

    hist_mean = ds.hist.mean(sim_dim)
    ref_mean = ds.ref.mean(ref_dim)

    af: xr.DataArray = u.get_correction(hist_mean, ref_mean, kind).rename("af")
    return af.to_dataset()
@map_blocks(reduces=[Grouper.PROP], scen=[])
def scaling_adjust(ds: xr.Dataset, *, group, interp, kind) -> xr.Dataset:
    """
    Scaling: Adjust on one block.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            af : Adjustment factors.
            sim : Data to adjust.
    group : Grouper
        The grouper object.
    interp : str
        The interpolation method used to broadcast `af` onto `sim`.
    kind : str
        The kind of correction, passed to `u.apply_correction`.

    Returns
    -------
    xr.Dataset
        The adjusted data, in the variable `scen`.
    """
    factors = u.broadcast(ds.af, ds.sim, group=group, interp=interp)
    corrected = u.apply_correction(ds.sim, factors, kind)
    scen: xr.DataArray = corrected.rename("scen")
    return scen.to_dataset()
def _fit_on_cluster(data, thresh, cluster_thresh, dist):
    """
    Extract clusters on 1D data and fit `dist` on the maximums.

    The fit is done on ``maximums - thresh`` with ``floc=0``, for 3 parameters with the
    "ML" method. The second-to-last fitted parameter is then replaced by `thresh`.

    Returns
    -------
    list
        The fitted parameters.
    """
    _, _, _, peaks = u.get_clusters_1d(data, thresh, cluster_thresh)
    excesses = peaks - thresh
    fitted = _fitfunc_1d(excesses, dist=dist, floc=0, nparams=3, method="ML")
    params = list(fitted)
    # We forced 0, put back thresh.
    params[-2] = thresh
    return params
def _extremes_train_1d(ref, hist, ref_params, cluster_thresh, *, q_thresh, dist, N):
"""Train for method ExtremeValues, only for 1D input along time."""
# Fast-track, do nothing for all-nan slices
if all(np.isnan(ref)) or all(np.isnan(hist)):
return np.full(N, np.nan), np.full(N, np.nan), np.nan
# Find quantile q_thresh
thresh = (np.nanquantile(ref[ref >= cluster_thresh], q_thresh) + np.nanquantile(hist[hist >= cluster_thresh], q_thresh)) / 2
# Fit genpareto on cluster maximums on ref (if needed) and hist.
if np.isnan(ref_params).all():
ref_params = _fit_on_cluster(ref, thresh, cluster_thresh, dist)
hist_params = _fit_on_cluster(hist, thresh, cluster_thresh, dist)
# Find probabilities of extremes according to fitted dist
Px_ref = dist.cdf(ref[ref >= thresh], *ref_params)
hist = hist[hist >= thresh]
Px_hist = dist.cdf(hist, *hist_params)
# Find common probabilities range.
Pmax = min(Px_ref.max(), Px_hist.max())
Pmin = max(Px_ref.min(), Px_hist.min())
Pcommon = (Px_hist <= Pmax) & (Px_hist >= Pmin)
Px_hist = Px_hist[Pcommon]
# Find values of hist extremes if they followed ref's distribution.
hist_in_ref = dist.ppf(Px_hist, *ref_params)
# Adjustment factors, unsorted
af = hist_in_ref / hist[Pcommon]
# sort them in Px order, and pad to have N values.
order = np.argsort(Px_hist)
if af.size > N:
raise ValueError(
"The number of precipitations part of a cluster is larger than `q_thresh`, which "
"likely indicates that `cluster_thresh` is too small for `ref` and/or `hist`, i.e."
"`cluster_thresh` is still in the bulk of the distribution."
)
px_hist = np.pad(Px_hist[order], ((0, N - af.size),), constant_values=np.nan)
af = np.pad(af[order], ((0, N - af.size),), constant_values=np.nan)
return px_hist, af, thresh
@map_blocks(reduces=["time"], px_hist=["quantiles"], af=["quantiles"], thresh=[Grouper.PROP])
def extremes_train(
    ds: xr.Dataset,
    *,
    group: Grouper,
    q_thresh: float,
    dist,
    quantiles: np.ndarray,
) -> xr.Dataset:
    """
    Train the extreme-value adjustment by applying `_extremes_train_1d` along time.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with the variables `ref`, `hist`, `ref_params` and `cluster_thresh`.
    group : Grouper
        The grouper object. It is not read in the body (presumably consumed by the `map_blocks` wrapper).
    q_thresh : float
        The quantile threshold, forwarded to `_extremes_train_1d`.
    dist : Any
        The distribution to fit, forwarded to `_extremes_train_1d`.
    quantiles : array-like
        Coordinate of the output `quantiles` dimension. Only its length is given to the 1D routine.

    Returns
    -------
    xr.Dataset
        `px_hist` and `af` along `quantiles`, and `thresh` with an added `group` dimension.
    """
    train_kwargs = {
        "q_thresh": q_thresh,
        "dist": dist,
        "N": len(quantiles),
    }
    outputs = xr.apply_ufunc(
        _extremes_train_1d,
        ds.ref,
        ds.hist,
        ds.ref_params or np.nan,
        ds.cluster_thresh,
        input_core_dims=[("time",), ("time",), (), ()],
        output_core_dims=[("quantiles",), ("quantiles",), ()],
        vectorize=True,
        kwargs=train_kwargs,
    )
    px_hist, af, thresh = outputs

    # Outputs of map_blocks must have dimensions.
    thresh = thresh if isinstance(thresh, xr.DataArray) else xr.DataArray(thresh)
    thresh = thresh.expand_dims(group=[1])

    trained = {"px_hist": px_hist, "af": af, "thresh": thresh}
    return xr.Dataset(trained, coords={"quantiles": quantiles})
def _fit_cluster_and_cdf(data, thresh, cluster_thresh, dist):
    """
    Fit `dist` on the cluster maximums of 1D `data`, then evaluate its CDF on all of `data`.

    The parameters come from `_fit_on_cluster` and are passed positionally to `dist.cdf`.
    """
    return dist.cdf(data, *_fit_on_cluster(data, thresh, cluster_thresh, dist))
@map_blocks(reduces=["quantiles", Grouper.PROP], scen=[])
def extremes_adjust(
    ds: xr.Dataset,
    *,
    group: Grouper,
    frac: float,
    power: float,
    dist,
    interp: str,
    extrapolation: str,
) -> xr.Dataset:
    """
    Adjust the extremes of `sim` and blend the result with an existing scenario.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with the variables `sim`, `scen`, `thresh`, `cluster_thresh`, `px_hist` and `af`.
    group : Grouper
        The grouper object. It is not read in the body (presumably consumed by the `map_blocks` wrapper).
    frac : float
        Divides the normalized excess over `thresh` before it is raised to `power`.
    power : float
        Exponent of the transition function.
    dist : Any
        The distribution fitted on the cluster maximums of `sim`.
    interp : str
        The interpolation method, passed to `u.interp_on_quantiles`.
    extrapolation : str
        The extrapolation method, passed to `u.interp_on_quantiles`.

    Returns
    -------
    xr.Dataset
        A dataset with the single variable `scen`, the `group` dimension being squeezed out.
    """
    # Probabilities of the future values according to a distribution fitted on their own clusters.
    px_fut = xr.apply_ufunc(
        _fit_cluster_and_cdf,
        ds.sim,
        ds.thresh,
        ds.cluster_thresh,
        input_core_dims=[["time"], [], []],
        output_core_dims=[["time"]],
        kwargs={"dist": dist},
        vectorize=True,
    )

    # Interpolate the trained factors from hist probabilities to fut probabilities, then apply them.
    factors = u.interp_on_quantiles(px_fut, ds.px_hist, ds.af, method=interp, extrapolation=extrapolation)
    corrected = u.apply_correction(ds.sim, factors, "*")

    # Transition weight between the incoming scenario (weight 0) and the corrected extremes (weight 1).
    # The excess is clipped at 0, so for a positive `power` values at or below `thresh` keep `ds.scen`.
    excess = (ds.sim - ds.thresh).clip(0, None)
    span = (ds.sim.max("time")) - ds.thresh
    weight = ((excess / span) / frac) ** power
    weight = weight.clip(0, 1)

    blended: xr.DataArray = (weight * corrected) + ((1 - weight) * ds.scen)
    return blended.rename("scen").squeeze("group", drop=True).to_dataset()
def _otc_adjust(
    X: np.ndarray,
    Y: np.ndarray,
    bin_width: dict | float | np.ndarray | None = None,
    bin_origin: dict | float | np.ndarray | None = None,
    num_iter_max: int | None = 100_000_000,
    jitter_inside_bins: bool = True,
    normalization: str | None = "max_distance",
):
    """
    Optimal Transport Correction of the bias of X with respect to Y.

    Parameters
    ----------
    X : np.ndarray
        Historical data to be corrected. 2D, one row per point and one column per variable.
    Y : np.ndarray
        Bias correction reference, target of optimal transport. Same layout as `X`.
    bin_width : dict or float or np.ndarray, optional
        Bin widths. A scalar is used for every column. A dict maps column indices to widths,
        the other columns being estimated with `u.bin_width_estimator`. If None, all widths are estimated.
    bin_origin : dict or float or np.ndarray, optional
        Bin origins. A scalar is used for every column. A dict maps column indices to origins,
        the other columns being 0. If None, all origins are 0.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth-Mover_Distance (EMD) algorithm.
        None is replaced by 100_000_000.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    np.ndarray
        Adjusted data with the shape of `X`. Rows of `X` containing a NaN are entirely NaN in the output.

    References
    ----------
    :cite:cts:`robin_2021`
    """
    # nans are removed and put back in place at the end
    Z = X.copy()
    mask = (~np.isnan(X)).all(axis=1)
    X = X[mask]
    Y = Y[(~np.isnan(Y)).all(axis=1)]

    # Initialize parameters
    if bin_width is None:
        bin_width = u.bin_width_estimator([Y, X])
    elif isinstance(bin_width, dict):
        _bin_width = u.bin_width_estimator([Y, X])
        for col, width in bin_width.items():
            _bin_width[col] = width
        bin_width = _bin_width
    elif isinstance(bin_width, float | int):
        bin_width = np.ones(X.shape[1]) * bin_width

    if bin_origin is None:
        bin_origin = np.zeros(X.shape[1])
    elif isinstance(bin_origin, dict):
        _bin_origin = np.zeros(X.shape[1])
        for col, origin in bin_origin.items():
            _bin_origin[col] = origin
        bin_origin = _bin_origin
    elif isinstance(bin_origin, float | int):
        bin_origin = np.ones(X.shape[1]) * bin_origin

    num_iter_max = 100_000_000 if num_iter_max is None else num_iter_max

    # Get the bin positions and frequencies of X and Y, and for all Xs the bin to which they belong
    gridX, muX, binX = u.histogram(X, bin_width, bin_origin)
    gridY, muY, _ = u.histogram(Y, bin_width, bin_origin)

    # Compute the optimal transportation plan
    plan = u.optimal_transport(gridX, gridY, muX, muY, num_iter_max, normalization)

    # Express the target grid as integer bin indices, so that bin centers can be rebuilt below
    # as (index + 1/2) * width + origin.
    gridY = np.floor((gridY - bin_origin) / bin_width)

    # regroup the indices of all the points belonging to a same bin
    binX_sort = np.lexsort(binX[:, ::-1].T)
    sorted_bins = binX[binX_sort]
    _, binX_start, binX_count = np.unique(sorted_bins, return_index=True, return_counts=True, axis=0)
    binX_start_sort = np.sort(binX_start)
    binX_groups = np.split(binX_sort, binX_start_sort[1:])

    out = np.empty(X.shape)
    rng = np.random.default_rng()
    # The plan row corresponding to a source bin indicates its probabilities to be transported to every target bin
    for i, binX_group in enumerate(binX_groups):
        # Pick as much target bins for this source bin as there are points in the source bin
        choice = rng.choice(muY.size, p=plan[i, :], size=binX_count[i])
        out[binX_group] = (gridY[choice] + 1 / 2) * bin_width + bin_origin

    if jitter_inside_bins:
        # Use the same generator as the bin selection above instead of the legacy global state.
        out += rng.uniform(low=-bin_width / 2, high=bin_width / 2, size=out.shape)

    # reintroduce nans
    Z[mask] = out
    Z[~mask] = np.nan
    return Z
@map_groups(scen=[Grouper.DIM])
def otc_adjust(
    ds: xr.Dataset,
    dim: list,
    pts_dim: str,
    bin_width: dict | float | None = None,
    bin_origin: dict | float | None = None,
    num_iter_max: int | None = 100_000_000,
    jitter_inside_bins: bool = True,
    adapt_freq_thresh: dict | None = None,
    normalization: str | None = "max_distance",
):
    """
    Optimal Transport Correction of the bias of `hist` with respect to `ref`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
    dim : list
        The dimensions defining the distribution on which optimal transport is performed.
    pts_dim : str
        The dimension defining the multivariate components of the distribution.
    bin_width : dict or float, optional
        Bin widths for specified dimensions. Dict keys are labels of `pts_dim`.
    bin_origin : dict or float, optional
        Bin origins for specified dimensions. Dict keys are labels of `pts_dim`.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover Distance (EMD) algorithm.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    adapt_freq_thresh : dict, optional
        Threshold for frequency adaptation per variable. Entries whose threshold is None are skipped.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    xr.Dataset
        Adjusted data, in the variable `scen`.
    """
    ref = ds.ref
    hist = ds.hist

    if adapt_freq_thresh is not None:
        for var, thresh in adapt_freq_thresh.items():
            if thresh is None:
                continue
            ds0 = xr.Dataset({"ref": ref.sel({pts_dim: var}), "sim": hist.sel({pts_dim: var})})
            # `dim` is a required argument of `_preprocess_dataset`.
            hist.loc[{pts_dim: var}] = _preprocess_dataset(ds0, dim=dim, adapt_freq_thresh=thresh).sim

    # Stack the distribution dimensions into a single one; `ref` dims are renamed first so that
    # they do not clash with those of `hist` in `apply_ufunc`.
    ref_dim = Grouper.filter_dim(ref, dim)
    ref_map = {d: f"ref_{d}" for d in ref_dim}
    ref = ref.rename(ref_map).stack(dim_ref=ref_map.values()).dropna(dim="dim_ref")

    sim_dim = Grouper.filter_dim(hist, dim)
    hist = hist.stack(dim_hist=sim_dim).dropna(dim="dim_hist")

    # `_otc_adjust` works on plain arrays: translate `pts_dim` labels into column positions.
    pts_labels = ref[pts_dim].values
    if isinstance(bin_width, dict):
        bin_width = {np.where(pts_labels == var)[0][0]: op for var, op in bin_width.items()}
    if isinstance(bin_origin, dict):
        bin_origin = {np.where(pts_labels == var)[0][0]: op for var, op in bin_origin.items()}

    scen = xr.apply_ufunc(
        _otc_adjust,
        hist,
        ref,
        kwargs={
            "bin_width": bin_width,
            "bin_origin": bin_origin,
            "num_iter_max": num_iter_max,
            "jitter_inside_bins": jitter_inside_bins,
            "normalization": normalization,
        },
        input_core_dims=[["dim_hist", pts_dim], ["dim_ref", pts_dim]],
        output_core_dims=[["dim_hist", pts_dim]],
        keep_attrs=True,
        vectorize=True,
    )

    scen = scen.unstack().rename("scen")
    return scen.to_dataset()
def _dotc_adjust(
    X1: np.ndarray,
    Y0: np.ndarray,
    X0: np.ndarray,
    bin_width: dict | float | None = None,
    bin_origin: dict | float | None = None,
    num_iter_max: int | None = 100_000_000,
    cov_factor: str | None = "std",
    jitter_inside_bins: bool = True,
    kind: dict | None = None,
    normalization: str | None = "max_distance",
):
    """
    Dynamical Optimal Transport Correction of the bias of X with respect to Y.

    Parameters
    ----------
    X1 : np.ndarray
        Simulation data to adjust.
    Y0 : np.ndarray
        Bias correction reference.
    X0 : np.ndarray
        Historical simulation data.
    bin_width : dict or float, optional
        Bin widths. A dict maps column indices to widths, the remaining columns being estimated
        with `u.bin_width_estimator`. A scalar is used for every column.
    bin_origin : dict or float, optional
        Bin origins. A dict maps column indices to origins, the remaining columns being 0.
        A scalar is used for every column.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover Distance (EMD) algorithm.
    cov_factor : str, optional
        Rescaling factor applied to the motion: "cholesky" or "std". Any other value leaves it unscaled.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
        Only the final transport uses it; the intermediate ones never jitter.
    kind : dict, optional
        Maps column indices to adjustment kinds. Columns mapped to "*" are multiplicative,
        all the others are additive.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    np.ndarray
        Adjusted data with the shape of `X1`. Rows of `X1` containing a NaN are entirely NaN in the output.

    References
    ----------
    :cite:cts:`robin_2021`
    """
    # Rows with NaNs are set aside now and restored as NaN in the result.
    result = X1.copy()
    valid = ~np.isnan(X1).any(axis=1)
    X1 = X1[valid]
    X0 = X0[~np.isnan(X0).any(axis=1)]
    Y0 = Y0[~np.isnan(Y0).any(axis=1)]

    # Resolve the bin parameters into per-column arrays where they are given.
    if isinstance(bin_width, dict):
        widths = u.bin_width_estimator([Y0, X0, X1])
        for col, width in bin_width.items():
            widths[col] = width
        bin_width = widths
    elif isinstance(bin_width, float | int):
        bin_width = np.ones(X0.shape[1]) * bin_width

    if isinstance(bin_origin, dict):
        origins = np.zeros(X0.shape[1])
        for col, origin in bin_origin.items():
            origins[col] = origin
        bin_origin = origins
    elif isinstance(bin_origin, float | int):
        bin_origin = np.ones(X0.shape[1]) * bin_origin

    # Arguments shared by the three transports below.
    otc_kwargs = {
        "bin_width": bin_width,
        "bin_origin": bin_origin,
        "num_iter_max": num_iter_max,
        "normalization": normalization,
    }

    # Map ref to hist
    yX0 = _otc_adjust(Y0, X0, jitter_inside_bins=False, **otc_kwargs)
    # Map hist to sim
    yX1 = _otc_adjust(yX0, X1, jitter_inside_bins=False, **otc_kwargs)

    # Columns flagged "*" evolve multiplicatively, the others additively.
    multiplicative = [kind is not None and kind.get(col) == "*" for col in range(yX0.shape[1])]

    # Temporal evolution
    motion = np.empty(yX0.shape)
    for col, is_mult in enumerate(multiplicative):
        if is_mult:
            motion[:, col] = yX1[:, col] / yX0[:, col]
        else:
            motion[:, col] = yX1[:, col] - yX0[:, col]

    # Apply a variance dependent rescaling factor
    if cov_factor == "cholesky":
        ref_fact = u.eps_cholesky(np.cov(Y0, rowvar=False))
        hist_fact = u.eps_cholesky(np.cov(X0, rowvar=False))
        motion = (ref_fact @ np.linalg.inv(hist_fact) @ motion.T).T
    elif cov_factor == "std":
        ref_fact = np.std(Y0, axis=0)
        hist_fact = np.std(X0, axis=0)
        motion = motion @ np.diag(ref_fact / hist_fact)

    # Apply the evolution to ref
    Y1 = np.empty(yX0.shape)
    for col, is_mult in enumerate(multiplicative):
        if is_mult:
            Y1[:, col] = Y0[:, col] * motion[:, col]
        else:
            Y1[:, col] = Y0[:, col] + motion[:, col]

    # Map sim to the evolution of ref
    adjusted = _otc_adjust(X1, Y1, jitter_inside_bins=jitter_inside_bins, **otc_kwargs)

    # reintroduce nans
    result[valid] = adjusted
    result[~valid] = np.nan
    return result
@map_groups(scen=[Grouper.DIM])
def dotc_adjust(
    ds: xr.Dataset,
    dim: list,
    pts_dim: str,
    bin_width: dict | float | None = None,
    bin_origin: dict | float | None = None,
    num_iter_max: int | None = 100_000_000,
    cov_factor: str | None = "std",
    jitter_inside_bins: bool = True,
    kind: dict | None = None,
    adapt_freq_thresh: dict | None = None,
    normalization: str | None = "max_distance",
):
    """
    Dynamical Optimal Transport Correction of the bias of X with respect to Y.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset variables:
            ref : training target
            hist : training data
            sim : simulated data
    dim : list
        The dimensions defining the distribution on which optimal transport is performed.
    pts_dim : str
        The dimension defining the multivariate components of the distribution.
    bin_width : dict or float, optional
        Bin widths for specified dimensions. Dict keys are labels of `pts_dim`.
    bin_origin : dict or float, optional
        Bin origins for specified dimensions. Dict keys are labels of `pts_dim`.
    num_iter_max : int, optional
        Maximum number of iterations used in the Earth Mover Distance (EMD) algorithm.
    cov_factor : str, optional
        Rescaling factor.
    jitter_inside_bins : bool
        If `False`, output points are located at the center of their bin.
        If `True`, a random location is picked uniformly inside their bin. Default is `True`.
    kind : dict, optional
        Keys are variable names and values are adjustment kinds, either additive or multiplicative.
        Unspecified dimensions are treated as "+".
    adapt_freq_thresh : dict, optional
        Threshold for frequency adaptation per variable. Entries whose threshold is None are skipped.
    normalization : {'standardize', 'max_distance', 'max_value'}, optional
        Per-variable transformation applied before the distances are calculated
        in the optimal transport.

    Returns
    -------
    xr.Dataset
        Adjusted data, in the variable `scen`.
    """
    hist = ds.hist
    sim = ds.sim
    ref = ds.ref

    if adapt_freq_thresh is not None:
        for var, thresh in adapt_freq_thresh.items():
            if thresh is not None:
                ds0 = xr.Dataset({"ref": ref.sel({pts_dim: var}), "sim": hist.sel({pts_dim: var})})
                # First pass adapts `hist`. The returned dataset presumably also carries the
                # frequency-adaptation variables (e.g. `P0_ref`, `P0_hist`) reused by the second
                # pass; confirm against `_preprocess_dataset`.
                ds0 = _preprocess_dataset(ds0, dim=dim, adapt_freq_thresh=thresh)
                hist.loc[{pts_dim: var}] = ds0.sim
                ds0["sim"] = sim.loc[{pts_dim: var}]
                # remove the `ref` dataset since we already have `P0_ref` and other datasets
                # (`Dataset.drop` is deprecated for dropping variables; `drop_vars` is its replacement)
                ds0 = ds0.drop_vars("ref")
                sim.loc[{pts_dim: var}] = _preprocess_dataset(ds0, dim=dim, adapt_freq_thresh=thresh).sim

    # Prepare for apply_ufunc: stack the distribution dimensions of each input into a single one,
    # renaming those of `hist` and `ref` so that the three inputs do not share core dimensions.
    # NOTE(review): `hist` is renamed using the dimensions filtered from `sim`, while `ref` uses
    # its own; confirm this is intended.
    sim_dim = Grouper.filter_dim(sim, dim)
    hist_map = {d: f"hist_{d}" for d in sim_dim}
    hist = hist.rename(hist_map).stack(dim_hist=hist_map.values())

    ref_dim = Grouper.filter_dim(ref, dim)
    ref_map = {d: f"ref_{d}" for d in ref_dim}
    ref = ref.rename(ref_map).stack(dim_ref=ref_map.values())

    sim = sim.stack(dim_sim=sim_dim)

    # `_dotc_adjust` works on plain arrays: translate `pts_dim` labels into column positions.
    pts_labels = ref[pts_dim].values
    if kind is not None:
        kind = {np.where(pts_labels == var)[0][0]: op for var, op in kind.items()}
    if isinstance(bin_width, dict):
        bin_width = {np.where(pts_labels == var)[0][0]: op for var, op in bin_width.items()}
    if isinstance(bin_origin, dict):
        bin_origin = {np.where(pts_labels == var)[0][0]: op for var, op in bin_origin.items()}

    scen = xr.apply_ufunc(
        _dotc_adjust,
        sim,
        ref,
        hist,
        kwargs={
            "bin_width": bin_width,
            "bin_origin": bin_origin,
            "num_iter_max": num_iter_max,
            "cov_factor": cov_factor,
            "jitter_inside_bins": jitter_inside_bins,
            "kind": kind,
            "normalization": normalization,
        },
        input_core_dims=[
            ["dim_sim", pts_dim],
            ["dim_ref", pts_dim],
            ["dim_hist", pts_dim],
        ],
        output_core_dims=[["dim_sim", pts_dim]],
        keep_attrs=True,
        vectorize=True,
    )

    scen = scen.unstack().rename("scen")
    return scen.to_dataset()