from __future__ import annotations
import math
import numpy as np
import numba
__all__ = [
"ndarray_linear_interpolation",
]
[docs]
def ndarray_linear_interpolation(
a: np.ndarray,
indices: tuple[np.ndarray, ...],
axis: None | int | tuple[int, ...] = None,
axis_indices: None | int | tuple[int, ...] = None,
):
"""
Interpolate a :class:`numpy.ndarray` onto a new grid.
Similar to :func:`scipy.ndimage.map_coordinates`, but allows for interpolation along only some of the
axes of ``a``.
Parameters
----------
a
The input array to be interpolated. Should have at least as many axes as ``len(indices)``
indices
The new indices where ``a`` will be evaluated. The number of indices, ``len(indices)`` should be equal
to the number of interpolation axes, ``len(axis)``
axis
The axes of ``a`` to interpolate. If :class:`None`, interpolate over all the axes of ``a``.
axis_indices
The axes of ``indices`` that are complementary to the axes in ``axis``.
If :class:`None`, all axes of indices are complementary to the axes in ``axis``.
Returns
-------
The interpolated array
Examples
--------
.. jupyter-execute::
import numpy as np
import matplotlib.pyplot as plt
import regridding
x = np.arange(5)
y = np.arange(6)
Interpolate a 1D array
.. jupyter-execute::
a = x * x
x_new = np.linspace(0, x.size - 1, num=20)
a_new = regridding.ndarray_linear_interpolation(a, indices=(x_new, ))
plt.figure(figsize=(6, 3));
plt.scatter(x, a, s=100, label="original", zorder=1);
plt.scatter(x_new, a_new, label="interpolated", zorder=0);
plt.legend();
"""
shape_a = a.shape
ndim_a = a.ndim
indices_broadcasted = np.broadcast(*indices)
shape_indices = indices_broadcasted.shape
ndim_indices = indices_broadcasted.ndim
if axis is None:
axis = tuple(range(ndim_a))
axis = np.lib.array_utils.normalize_axis_tuple(axis, ndim=ndim_a)
axis = tuple(~(~np.array(axis) % ndim_a))
if axis_indices is None:
axis_indices = tuple(range(ndim_indices))
axis_indices = np.lib.array_utils.normalize_axis_tuple(
axis=axis_indices,
ndim=ndim_indices,
)
axis_indices = tuple(~(~np.array(axis_indices) % ndim_indices))
if len(indices) != len(axis):
raise ValueError(
f"The number of indices, {len(indices)}, must match the number of elements in axis, {len(axis)}"
)
axis_orthogonal_a = tuple(ax for ax in range(-ndim_a, 0) if ax not in axis)
axis_orthogonal_indices = tuple(
ax for ax in range(-ndim_indices, 0) if ax not in axis_indices
)
shape_orthogonal_a = tuple(shape_a[ax] for ax in axis_orthogonal_a)
shape_orthogonal_indices = tuple(
shape_indices[ax] for ax in axis_orthogonal_indices
)
shape_orthogonal = np.broadcast_shapes(shape_orthogonal_a, shape_orthogonal_indices)
ndim_broadcasted_a = len(shape_orthogonal) + len(axis)
ndim_broadcasted_indices = len(shape_orthogonal) + len(axis_indices)
shape_broadcasted_a = tuple(
shape_a[ax] if ax in axis else shape_orthogonal[axis_orthogonal_a.index(ax)]
for ax in range(-ndim_broadcasted_a, 0)
)
shape_broadcasted_indices = tuple(
(
shape_indices[ax]
if ax in axis_indices
else shape_orthogonal[axis_orthogonal_indices.index(ax)]
)
for ax in range(-ndim_broadcasted_indices, 0)
)
a = np.broadcast_to(a, shape_broadcasted_a)
indices = tuple(np.broadcast_to(ind, shape_broadcasted_indices) for ind in indices)
result = np.empty(shape_broadcasted_indices)
def index_a_indices(index: tuple[int, ...]) -> tuple[tuple, tuple]:
index_a = tuple(
slice(None) if ax in axis else index[axis_orthogonal_a.index(ax)]
for ax in range(-ndim_broadcasted_a, 0)
)
index_indices = tuple(
(
slice(None)
if ax in axis_indices
else index[axis_orthogonal_indices.index(ax)]
)
for ax in range(-ndim_broadcasted_indices, 0)
)
return index_a, index_indices
if len(axis) == 1:
(x,) = indices
for index in np.ndindex(*shape_orthogonal):
index_a, index_indices = index_a_indices(index)
result[index_indices] = _ndarray_linear_interpolation_1d(
a=a[index_a],
x=x[index_indices].reshape(-1),
).reshape(x[index_indices].shape)
elif len(axis) == 2:
x, y = indices
for index in np.ndindex(*shape_orthogonal):
index_a, index_indices = index_a_indices(index)
result[index_indices] = _ndarray_linear_interpolation_2d(
a=a[index_a],
x=x[index_indices].reshape(-1),
y=y[index_indices].reshape(-1),
).reshape(x[index_indices].shape)
else:
raise NotImplementedError
return result
@numba.njit(cache=True, parallel=True)
def _ndarray_linear_interpolation_1d(
a: np.ndarray,
x: np.ndarray,
) -> np.ndarray:
shape_output = x.shape
(size_output,) = shape_output
result = np.empty(size_output)
for i in numba.prange(size_output):
result[i] = _linear_interpolation(a, x[i])
return result
@numba.njit(cache=True, parallel=True)
def _ndarray_linear_interpolation_2d(
a: np.ndarray,
x: np.ndarray,
y: np.ndarray,
) -> np.ndarray:
shape_output = x.shape
(size_output,) = shape_output
result = np.empty(size_output)
for i in numba.prange(size_output):
result[i] = _bilinear_interpolation(a, x[i], y[i])
return result
@numba.njit(cache=True)
def _linear_interpolation(
a: np.ndarray,
x: float,
) -> float:
(shape_input_x,) = a.shape
x_0 = int(math.floor(x))
if x_0 < 0:
x_0 = 0
elif x_0 > shape_input_x - 2:
x_0 = shape_input_x - 2
x_1 = x_0 + 1
a_0 = a[x_0]
a_1 = a[x_1]
dx = x - x_0
return a_0 * (1 - dx) + a_1 * dx
@numba.njit(cache=True)
def _bilinear_interpolation(
a: np.ndarray,
x: float,
y: float,
) -> float:
shape_input_x, shape_input_y = a.shape
x_00 = int(math.floor(x))
y_00 = int(math.floor(y))
if x_00 < 0:
x_00 = 0
elif x_00 > shape_input_x - 2:
x_00 = shape_input_x - 2
if y_00 < 0:
y_00 = 0
elif y_00 > shape_input_y - 2:
y_00 = shape_input_y - 2
x_01 = x_00
x_10 = x_11 = x_00 + 1
y_01 = y_11 = y_00 + 1
y_10 = y_00
a_00 = a[x_00, y_00]
a_01 = a[x_01, y_01]
a_10 = a[x_10, y_10]
a_11 = a[x_11, y_11]
dx = x - x_00
dy = y - y_00
w_00 = (1 - dx) * (1 - dy)
w_01 = (1 - dx) * dy
w_10 = dx * (1 - dy)
w_11 = dx * dy
return (a_00 * w_00) + (a_01 * w_01) + (a_10 * w_10) + (a_11 * w_11)