82 lines
2.6 KiB
Python

from typing import Sequence
import numpy as np
def _validate_carry(arr: np.ndarray) -> tuple[int, tuple[int, int]]:
"""
Ensure that `arr` is interpretable as an explicit carry rule.
See `Carry` for more information.
"""
overflow = None
over_pos = None
coeff_sum = 0
# find largest digit for carry
for i, row in enumerate(arr):
for j, val in enumerate(row):
coeff_sum += val
if val < 0:
if overflow is not None:
raise ValueError("Array supplied has more than one overflow digit")
overflow = -val
over_pos = (i, j)
if coeff_sum > 0:
raise ValueError("Sum of coefficients too small")
if overflow is None or over_pos is None:
raise ValueError("No dominating term supplied")
return overflow, over_pos
class Carry:
"""
Two-dimensional carry rule.
This is an (integer) Numpy array which is greedily "applied" to another (integer) Numpy array to
rid it of integers.
Only "explicit" carries are allowed -- these are carries that are all negative except for a
single positive term, such that the sum of all terms is non-negative.
For example, the Laplacian kernel, as a carry rule is:
`laplace = Carry([[0,1,0],[1,-4,1],[0,1,0]])`
Interpreted as a carry rule, this spreads integers above 4 into their neighboring cells.
"""
def __init__(self, arr: np.ndarray | Sequence):
if not isinstance(arr, np.ndarray):
arr = np.array(arr)
overflow, over_pos = _validate_carry(arr)
self._arr: np.ndarray = arr
self.over_pos: tuple[int, int] = over_pos
self.overflow: int = overflow
def _apply(
self,
base: np.ndarray,
update_tuples: list[tuple[int, int]],
zerowall: np.ndarray,
):
"""Apply the carry at the locations in `base` specified by `update_tuples`."""
pos_x, pos_y = self.over_pos
for x, y in update_tuples:
idx = base[(x, y)]
dec = idx // self.overflow # - 1
wall = zerowall.copy()
sh_x, sh_y = self._arr.shape
x -= pos_x
y -= pos_y
wall[x : x + sh_x, y : y + sh_y] = dec * self._arr
base += wall
def apply(self, base: np.ndarray) -> np.ndarray:
"""Update an expansion `base` according to this carry."""
zerowall = np.zeros(base.shape, dtype=base.dtype)
while update_tuples := list(zip(*np.where(base >= self.overflow))):
self._apply(base, update_tuples, zerowall)
return base