Source code for xsdba.nbutils

"""
Numba-accelerated Utilities
===========================
"""

from __future__ import annotations
import warnings
from collections.abc import Hashable, Sequence

import numpy as np
from numba import boolean, float32, float64, guvectorize, njit
from xarray import DataArray, apply_ufunc
from xarray.core import utils


try:
    from fastnanquantile.xrcompat import xr_apply_nanquantile

    USE_FASTNANQUANTILE = True
except ImportError:
    USE_FASTNANQUANTILE = False


@njit(
    fastmath={"arcp", "contract", "reassoc", "nsz", "afn"},
    nogil=True,
    cache=True,
)
def _get_indexes(arr: np.array, virtual_indexes: np.array, valid_values_count: np.array) -> tuple[np.array, np.array]:
    """
    Find the integer positions of `arr` that bracket each virtual index.

    Parameters
    ----------
    arr : array-like
        Array the indexes refer to. Only its dtype is inspected here.
    virtual_indexes : array-like
        Fractional positions for which the two neighbouring indexes are requested.
    valid_values_count : array-like
        Count of valid values in `arr`; positions at or beyond ``valid_values_count - 1``
        are treated as being above bounds.

    Returns
    -------
    array-like, array-like
        The previous and next integer indexes (cast to ``np.intp``) of each virtual index.

    Notes
    -----
    This is a companion function to linear interpolation of quantiles.
    """
    lower = np.asarray(np.floor(virtual_indexes))
    upper = np.asarray(lower + 1)

    # Positions at or past the last valid element are redirected to index -1 (last slot).
    past_end = virtual_indexes >= valid_values_count - 1
    if past_end.any():
        lower[past_end] = -1
        upper[past_end] = -1

    # Negative positions are redirected to index 0 (first slot).
    # This must run after the "past end" step so that it wins when both masks match.
    before_start = virtual_indexes < 0
    if before_start.any():
        lower[before_start] = 0
        upper[before_start] = 0

    if (arr.dtype is np.dtype(np.float64)) or (arr.dtype is np.dtype(np.float32)):
        # Undefined (NaN) positions also point at the last slot, which holds a NaN
        # once a slice containing NaNs has been sorted.
        undefined = np.isnan(virtual_indexes)
        if undefined.any():
            lower[undefined] = -1
            upper[undefined] = -1

    return lower.astype(np.intp), upper.astype(np.intp)


@njit(
    fastmath={"arcp", "contract", "reassoc", "nsz", "afn"},
    nogil=True,
    cache=True,
)
def _linear_interpolation(
    left: np.ndarray,
    right: np.ndarray,
    gamma: np.ndarray,
) -> np.ndarray:
    """
    Interpolate linearly between two same-shape arrays, point by point, with weight gamma.

    Parameters
    ----------
    left : array_like
        Left bound.
    right : array_like
        Right bound.
    gamma : array_like
        The interpolation weight.

    Returns
    -------
    array_like
        ``left + (right - left) * gamma`` where ``gamma < 0.5``, and
        ``right - (right - left) * (1 - gamma)`` elsewhere.

    Notes
    -----
    This is a companion function for `_nan_quantile_1d`
    """
    span = right - left
    blended = np.asarray(left + span * gamma)
    # For weights in the upper half, express the result relative to the right bound instead.
    near_right = gamma >= 0.5
    blended[near_right] = right[near_right] - span[near_right] * (1 - gamma[near_right])
    return blended


@njit(
    fastmath={"arcp", "contract", "reassoc", "nsz", "afn"},
    nogil=True,
    cache=True,
)
def _nan_quantile_1d(
    arr: np.ndarray,
    quantiles: np.ndarray,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> float | np.ndarray:
    """
    Get the quantiles of the 1-dimensional array.

    A linear interpolation is performed using alpha and beta.

    Parameters
    ----------
    arr : np.ndarray
        1-dimensional data, possibly containing NaNs. It is not modified.
    quantiles : np.ndarray
        The quantiles to compute.
    alpha : float
        First plotting-position parameter of the interpolation.
    beta : float
        Second plotting-position parameter of the interpolation.

    Returns
    -------
    float or np.ndarray
        The interpolated quantiles, one per entry of `quantiles`.

    Notes
    -----
    By default, `alpha == beta == 1` which performs the 7th method of :cite:t:`hyndman_sample_1996`.
    with `alpha == beta == 1/3` we get the 8th method. alpha == beta == 1 reproduces the behaviour of `np.nanquantile`.
    """
    # Number of non-NaN entries; it drives the position of each quantile in the sorted data.
    valid_values_count = (~np.isnan(arr)).sum()

    # Computation of indexes
    virtual_indexes = valid_values_count * quantiles + (alpha + quantiles * (1 - alpha - beta)) - 1
    virtual_indexes = np.asarray(virtual_indexes)
    previous_indexes, next_indexes = _get_indexes(arr, virtual_indexes, valid_values_count)

    # Sort a copy rather than in place: `arr` may be a view of the caller's data
    # (e.g. one row of a larger array) and must be left untouched.
    sorted_arr = np.sort(arr)

    previous = sorted_arr[previous_indexes]
    next_elements = sorted_arr[next_indexes]

    # Linear interpolation
    gamma = np.asarray(virtual_indexes - previous_indexes, dtype=arr.dtype)
    interpolation = _linear_interpolation(previous, next_elements, gamma)
    # When an interpolation is in NaN range (near the end of the sorted array), it means
    # we can clip to the array max value, i.e. the last valid element of the sorted data.
    result = np.where(np.isnan(interpolation), sorted_arr[np.intp(valid_values_count) - 1], interpolation)
    return result


@guvectorize(
    [(float32[:], float32, float32[:]), (float64[:], float64, float64[:])],
    "(n),()->()",
    nopython=True,
    cache=True,
)
def _vecquantiles(arr, rnk, res):
    """Write the NaN-ignoring quantile `rnk` of `arr` into `res[0]`, or NaN when `rnk` is NaN."""
    if not np.isnan(rnk):
        res[0] = np.nanquantile(arr, rnk)
    else:
        res[0] = np.nan


def vecquantiles(da: DataArray, rnk: DataArray, dim: str | Sequence[Hashable]) -> DataArray:
    """
    Compute quantiles along `dim` when the requested quantile differs at each point.

    `da` and `rnk` must share all dimensions but `dim`.

    Parameters
    ----------
    da : xarray.DataArray
        The data to compute the quantiles on.
    rnk : xarray.DataArray
        The quantiles to compute.
    dim : str or sequence of str
        The dimension along which to compute the quantiles.

    Returns
    -------
    xarray.DataArray
        The quantiles computed along the `dim` dimension.
    """
    reduce_dims = [dim] if isinstance(dim, str) else dim
    # Collapse the reduced dimension(s) into one temporary dimension placed last,
    # so that the data lines up with `rnk` on every other axis.
    stacked_name = utils.get_temp_dimname(da.dims, "temporal")
    stacked = da.stack({stacked_name: reduce_dims}).transpose(*rnk.dims, stacked_name)

    values = _vecquantiles(stacked.values, rnk.values)
    out = DataArray(values, dims=rnk.dims, coords=rnk.coords, attrs=stacked.attrs)
    return out.astype(stacked.dtype)
@njit
def _wrapper_quantile1d(arr, q):
    """Compute the quantiles `q` of every row of the 2D array `arr` with `_nan_quantile_1d`."""
    n_rows = arr.shape[0]
    out = np.empty((n_rows, q.size), dtype=arr.dtype)
    for row in range(n_rows):
        out[row] = _nan_quantile_1d(arr[row], q)
    return out


def _quantile(arr, q, nreduce=None):
    """
    Compute the quantiles `q` of `arr` over its last `nreduce` axes.

    When `nreduce` is None (or 0), all axes are reduced. The output keeps the leading
    (non-reduced) axes of `arr` and gains a trailing axis of length ``len(q)``.
    """
    nreduce = nreduce or arr.ndim
    if arr.ndim == nreduce:
        # Everything is reduced: a single 1D computation on a flattened copy.
        return _nan_quantile_1d(arr.flatten(), q)

    first_reduced = arr.ndim - nreduce
    # Total number of elements over the reduced (trailing) axes.
    reduced_size = np.prod([arr.shape[axis] for axis in range(first_reduced, arr.ndim)])
    # Shape of the kept (leading) axes.
    kept_shape = [arr.shape[axis] for axis in range(first_reduced)]

    # Reshape as (kept, reduced), compute row by row, then restore the kept axes.
    flat = arr.reshape(-1, reduced_size)
    return _wrapper_quantile1d(flat, q).reshape(kept_shape + [len(q)])
def quantile(da: DataArray, q: np.ndarray, dim: str | Sequence[Hashable]) -> DataArray:
    """
    Compute the quantiles from a fixed list `q`.

    Uses `fastnanquantile` when it is installed and `q` has at most 1000 entries;
    otherwise falls back on the built-in implementation (with a warning if
    `fastnanquantile` is installed but `q` is too long).

    Parameters
    ----------
    da : xarray.DataArray
        The data to compute the quantiles on.
    q : array-like
        The quantiles to compute.
    dim : str or sequence of str
        The dimension along which to compute the quantiles.

    Returns
    -------
    xarray.DataArray
        The quantiles computed along the `dim` dimension.
    """
    if USE_FASTNANQUANTILE is True:
        if len(q) <= 1000:
            return xr_apply_nanquantile(da, dim=dim, q=q).rename({"quantile": "quantiles"})
        warnings.warn(
            "`fastnanquantile` is installed and would thus normally be used by default. However, it doesn't "
            f"work with more than 1000 quantiles (`len(q) = {len(q)}` was given). `xsdba` built-in functions will "
            "be used instead.",
            stacklevel=2,
        )

    reduce_dims = [dim] if isinstance(dim, str) else dim
    computed = apply_ufunc(
        _quantile,
        da,
        input_core_dims=[reduce_dims],
        exclude_dims=set(reduce_dims),
        output_core_dims=[["quantiles"]],
        output_dtypes=[da.dtype],
        dask_gufunc_kwargs={"output_sizes": {"quantiles": len(q)}},
        dask="parallelized",
        # The quantiles are cast to the dtype of the data before being passed on.
        kwargs={"nreduce": len(reduce_dims), "q": np.array(q, dtype=da.dtype)},
    )
    return computed.assign_coords(quantiles=q).assign_attrs(da.attrs)
@njit(
    [
        float32[:, :](float32[:, :]),
        float64[:, :](float64[:, :]),
    ],
    fastmath=False,
    nogil=True,
    cache=True,
)
def remove_NaNs(x):  # noqa: N802
    """Drop every column of the 2D array `x` that holds a NaN in at least one row."""
    has_nan = np.zeros_like(x[0, :], dtype=boolean)
    for row in range(x.shape[0]):
        has_nan = has_nan | np.isnan(x[row, :])
    keep = ~has_nan
    return x[:, keep]
@njit(
    [
        float32(float32[:, :], float32[:, :]),
        float64(float64[:, :], float64[:, :]),
    ],
    fastmath=True,
    nogil=True,
    cache=True,
)
def _correlation(X, Y):
    """
    Compute a correlation as the mean of pairwise distances between points in X and Y.

    X is KxN and Y is KxM, the result is the mean of the MxN euclidean distances.
    Similar to the mean of scipy.spatial.distance.cdist(X, Y, 'euclidean').
    """
    n_dims = X.shape[0]
    n_x = X.shape[1]
    n_y = Y.shape[1]
    total = 0
    for i in range(n_x):
        for j in range(n_y):
            sq_dist = 0
            for k in range(n_dims):
                sq_dist += (X[k, i] - Y[k, j]) ** 2
            total += np.sqrt(sq_dist)
    return total / (n_x * n_y)


@njit(
    [
        float32(float32[:, :]),
        float64(float64[:, :]),
    ],
    fastmath=True,
    nogil=True,
    cache=True,
)
def _autocorrelation(X):
    """
    Mean of the NxN pairwise distances of points in X of shape KxN.

    Only the pairs with ``j < i`` are visited; their sum is doubled and divided by N**2.
    Similar to scipy.spatial.distance.pdist(..., 'euclidean').
    """
    n_dims = X.shape[0]
    n_pts = X.shape[1]
    total = 0
    for i in range(n_pts):
        for j in range(i):
            sq_dist = 0
            for k in range(n_dims):
                sq_dist += (X[k, i] - X[k, j]) ** 2
            total += np.sqrt(sq_dist)
    return (2 * total) / n_pts**2


@guvectorize(
    [
        (float32[:, :], float32[:, :], float32[:]),
        (float64[:, :], float64[:, :], float64[:]),
    ],
    "(k, n),(k, m)->()",
    nopython=True,
    cache=True,
)
def _escore(tgt, sim, out):
    """
    E-score based on the Székely-Rizzo e-distances between clusters.

    tgt and sim are KxN and KxM, where dimensions are along K and observations along M and N.
    Observations (columns) holding a NaN are dropped from each input first; if either input
    has no observation left, the result is NaN.
    """
    sim_valid = remove_NaNs(sim)
    tgt_valid = remove_NaNs(tgt)

    n1 = sim_valid.shape[1]
    n2 = tgt_valid.shape[1]

    if n1 == 0 or n2 == 0:
        out[0] = np.nan
    else:
        sXY = _correlation(tgt_valid, sim_valid)
        sXX = _autocorrelation(tgt_valid)
        sYY = _autocorrelation(sim_valid)

        w = n1 * n2 / (n1 + n2)
        out[0] = w * (sXY + sXY - sXX - sYY) / 2


@njit(
    fastmath=False,
    nogil=True,
    cache=True,
)
def _first_and_last_nonnull(arr):
    """For each row of arr, get the first and last non NaN elements (NaN, NaN if there are none)."""
    n_rows = arr.shape[0]
    out = np.empty((n_rows, 2))
    for row in range(n_rows):
        valid = np.where(~np.isnan(arr[row]))[0]
        if valid.size > 0:
            out[row, 0] = arr[row][valid[0]]
            out[row, 1] = arr[row][valid[-1]]
        else:
            out[row, 0] = np.nan
            out[row, 1] = np.nan
    return out


@njit(
    fastmath=False,
    nogil=True,
    cache=True,
)
def _extrapolate_on_quantiles(interp, oldx, oldg, oldy, newx, newg, method="constant"):
    """
    Apply extrapolation to the output of interpolation on quantiles with a given grouping.

    Arguments are the same as _interp_on_quantiles_2d.

    Values of `newx` outside the first/last non-NaN values of `oldx` (interpolated along the
    group coordinate) are overwritten in `interp`, which is modified in place and returned.
    With ``method="constant"`` they receive the first/last non-NaN values of `oldy`;
    with any other method they are set to NaN.
    """
    x_bounds = _first_and_last_nonnull(oldx)
    group_coord = oldg[:, 0]

    below = newx < np.interp(newg, group_coord, x_bounds[:, 0])
    above = newx > np.interp(newg, group_coord, x_bounds[:, 1])

    if method == "constant":
        y_bounds = _first_and_last_nonnull(oldy)
        low_fill = np.interp(newg, group_coord, y_bounds[:, 0])
        high_fill = np.interp(newg, group_coord, y_bounds[:, 1])
        interp[below] = low_fill[below]
        interp[above] = high_fill[above]
    else:  # 'nan'
        interp[below] = np.nan
        interp[above] = np.nan
    return interp


@njit(
    fastmath=False,
    nogil=True,
    cache=True,
)
def _pairwise_haversine_and_bins(lond, latd, transpose=False):
    """
    Inter-site distances with the haversine approximation.

    `lond` and `latd` are given in degrees. Returns the NxN distance matrix together with its
    NaN-ignoring minimum and maximum. Only the upper triangle is filled unless `transpose`
    is True, in which case the matrix is made symmetric and its diagonal set to 0 (after the
    minimum and maximum have been computed). Unfilled entries are NaN.
    """
    n_sites = lond.shape[0]
    lon = np.deg2rad(lond)
    lat = np.deg2rad(latd)
    dists = np.full((n_sites, n_sites), np.nan)
    for i in range(n_sites - 1):
        sin_i = np.sin(lat[i])
        cos_i = np.cos(lat[i])
        for j in range(i + 1, n_sites):
            sin_j = np.sin(lat[j])
            cos_j = np.cos(lat[j])
            dlon = lon[j] - lon[i]
            cos_dlon = np.cos(dlon)
            # 6367 is presumably the Earth's radius in km -- TODO confirm.
            dists[i, j] = 6367 * np.arctan2(
                np.sqrt((cos_j * np.sin(dlon)) ** 2 + (cos_i * sin_j - sin_i * cos_j * cos_dlon) ** 2),
                sin_i * sin_j + cos_i * cos_j * cos_dlon,
            )
            if transpose:
                dists[j, i] = dists[i, j]
    mn = np.nanmin(dists)
    mx = np.nanmax(dists)
    if transpose:
        np.fill_diagonal(dists, 0)
    return dists, mn, mx