from typing import Sequence
import numpy as np
import numba
from regridding import _util
# Names exported by ``from <this module> import *``.
__all__ = [
    "regrid_from_weights",
]
[docs]
def regrid_from_weights(
    weights: np.ndarray,
    shape_input: tuple[int, ...],
    shape_output: tuple[int, ...],
    values_input: np.ndarray,
    values_output: None | np.ndarray = None,
    axis_input: None | int | Sequence[int] = None,
    axis_output: None | int | Sequence[int] = None,
) -> np.ndarray:
    """
    Regrid an array of values using weights computed by
    :func:`regridding.weights`.

    Parameters
    ----------
    weights
        Ragged array of weights computed by :func:`regridding.weights`.
    shape_input
        Broadcasted shape of the input coordinates computed by :func:`regridding.weights`.
    shape_output
        Broadcasted shape of the output coordinates computed by :func:`regridding.weights`.
    values_input
        Input array of values to be resampled.
    values_output
        Optional array in which to place the output.
        Its shape must equal the full shape of the result.
        It is zeroed before the weighted sums are accumulated, and it holds
        the regridded values when this function returns, whether or not an
        internal contiguous copy was needed.
    axis_input
        Logical axes of the input array to resample.
        If :obj:`None`, resample all the axes of the input array.
        The number of axes should be equal to the number of
        coordinates in the original input grid passed to :func:`regridding.weights`.
    axis_output
        Logical axes of the output array corresponding to the resampled axes
        of the input array.
        If :obj:`None`, all the axes of the output array correspond to resampled
        axes in the input grid.
        The number of axes should be equal to the original number of
        coordinates in the output grid passed to :func:`regridding.weights`.

    Returns
    -------
    The regridded values.
    The non-resampled ("orthogonal") axes are the broadcast of the orthogonal
    axes of `shape_input`, `shape_output` and `values_input`; the resampled
    axes take their lengths from `shape_output`.
    If `values_input` has a ``unit`` attribute, the result is returned
    as ``result << unit``.

    Raises
    ------
    ValueError
        If `values_output` is given and its shape differs from the shape
        of the result.

    See Also
    --------
    :func:`regridding.regrid`
    :func:`regridding.weights`
    """
    # Remember the unit (if any) now: the contiguous buffers handed to the
    # compiled kernel are re-created with ``np.ascontiguousarray`` and the
    # unit is re-attached to the result at the very end.
    unit = getattr(values_input, "unit", None)

    ndim_input = len(shape_input)
    ndim_output = len(shape_output)

    # NOTE(review): the ``~ax`` insertions below only land in the right slot
    # if ``_normalize_axis`` returns *negative* axis indices -- confirm
    # against ``regridding._util``.
    axis_input = _util._normalize_axis(axis_input, ndim=ndim_input)
    axis_output = _util._normalize_axis(axis_output, ndim=ndim_output)

    # Lengths of the axes that are NOT resampled, for each of the three
    # participants (input grid, output grid, input values).
    shape_input_orthogonal = tuple(
        shape_input[i]
        for i in _util._normalize_axis(None, ndim=ndim_input)
        if i not in axis_input
    )
    shape_output_orthogonal = tuple(
        shape_output[i]
        for i in _util._normalize_axis(None, ndim=ndim_output)
        if i not in axis_output
    )
    if np.ndim(values_input) > 0:
        shape_values_orthogonal = tuple(
            values_input.shape[i]
            for i in _util._normalize_axis(None, ndim=values_input.ndim)
            if i not in axis_input
        )
    else:
        shape_values_orthogonal = ()

    shape_orthogonal = np.broadcast_shapes(
        shape_input_orthogonal,
        shape_output_orthogonal,
        shape_values_orthogonal,
    )

    axis_input = tuple(sorted(axis_input))
    axis_output = tuple(sorted(axis_output))

    # Rebuild the full input/output shapes by inserting the resampled axis
    # lengths into the broadcast orthogonal shape. The list is built
    # reversed so that each axis can be inserted at position ``~ax``.
    shape_input_new = list(reversed(shape_orthogonal))
    for ax in reversed(axis_input):
        shape_input_new.insert(~ax, shape_input[ax])
    shape_input = tuple(reversed(shape_input_new))

    shape_output_new = list(reversed(shape_orthogonal))
    for ax in reversed(axis_output):
        shape_output_new.insert(~ax, shape_output[ax])
    shape_output = tuple(reversed(shape_output_new))

    weights = np.broadcast_to(np.array(weights), shape_orthogonal, subok=True)
    values_input = np.broadcast_to(values_input, shape_input, subok=True)

    if values_output is None:
        values_output_user = None
        values_output = np.zeros_like(values_input, shape=shape_output, dtype=float)
    else:
        if values_output.shape != shape_output:  # pragma: nocover
            raise ValueError(
                f"{values_output.shape=} should be equal to {shape_output}"
            )
        # The kernel accumulates with ``+=``, so start from zero.
        values_output.fill(0)
        # Keep a handle on the caller's buffer so the result can be written
        # back if the steps below end up working on a copy.
        values_output_user = np.asarray(values_output)

    # Move the resampled axes to the end and collapse all orthogonal axes
    # into one leading axis, which is the layout the kernel iterates over.
    axis_input_numba = ~np.arange(len(axis_input))[::-1]
    axis_output_numba = ~np.arange(len(axis_output))[::-1]

    shape_input_numba = tuple(shape_input[ax] for ax in axis_input)
    shape_output_numba = tuple(shape_output[ax] for ax in axis_output)

    values_input = np.moveaxis(values_input, axis_input, axis_input_numba)
    values_output = np.moveaxis(values_output, axis_output, axis_output_numba)

    shape_output_tmp = values_output.shape

    values_input = values_input.reshape(-1, *shape_input_numba)
    values_output = values_output.reshape(-1, *shape_output_numba)

    weights = numba.typed.List(weights.reshape(-1))

    # These may copy (e.g. when the resampled axes were not already the
    # trailing axes), in which case the kernel writes into the copy.
    values_input = np.ascontiguousarray(values_input)
    values_output = np.ascontiguousarray(values_output)

    _regrid_from_weights(
        weights=weights,
        values_input=values_input,
        values_output=values_output,
    )

    # Undo the collapse and the axis move. The shape is passed as a tuple
    # so that an empty shape is also accepted by ``reshape``.
    values_output = values_output.reshape(shape_output_tmp)
    values_output = np.moveaxis(values_output, axis_output_numba, axis_output)

    # If a caller-supplied buffer was given but the kernel ran on a copy,
    # write the result back so the buffer really contains the output.
    if values_output_user is not None:
        if not np.may_share_memory(values_output_user, values_output):
            np.copyto(values_output_user, values_output)
            values_output = values_output_user

    if unit is not None:
        values_output = values_output << unit

    return values_output
@numba.njit(cache=True, parallel=True)
def _regrid_from_weights(
    weights: numba.typed.List,
    values_input: np.ndarray,
    values_output: np.ndarray,
) -> None:
    """
    Accumulate weighted input values into the output array, in place.

    Parameters
    ----------
    weights
        List with one entry per index along the first axis of
        `values_input` and `values_output`.
        Each entry is indexed by an integer and every element unpacks into
        three values: a flat input index, a flat output index and a weight.
    values_input
        Input values; the first axis is iterated over and the remaining
        axes are flattened before indexing.
    values_output
        Output values, laid out like `values_input`.
        The weighted contributions are added with ``+=``, so this array is
        not cleared here and must already hold the desired starting values
        (the caller in this module zeroes it first).
    """
    # Each iteration only touches ``values_output[d]``, so the iterations
    # of the parallel loop write to separate slices of the output.
    for d in numba.prange(len(weights)):
        # Cast the loop index to a signed 64-bit integer.
        # NOTE(review): presumably needed so the typed list can be indexed
        # with the ``prange`` index -- confirm.
        d = numba.types.int64(d)
        weights_d = weights[d]
        # Flatten so the indices stored in the weights can be used directly.
        # The caller passes contiguous arrays, so these are expected to be
        # views and the ``+=`` below updates ``values_output`` itself.
        values_input_d = values_input[d].reshape(-1)
        values_output_d = values_output[d].reshape(-1)
        for w in range(len(weights_d)):
            i_input, i_output, weight = weights_d[w]
            # The indices are cast with ``int`` before use; presumably they
            # are stored alongside the weight in a floating-point row --
            # TODO confirm.
            values_output_d[int(i_output)] += weight * values_input_d[int(i_input)]