from typing import Sequence import numpy as np def _validate_carry(arr: np.ndarray) -> tuple[int, tuple[int, int]]: """ Ensure that `arr` is interpretable as an explicit carry rule. See `Carry` for more information. """ overflow = None over_pos = None coeff_sum = 0 # find largest digit for carry for i, row in enumerate(arr): for j, val in enumerate(row): coeff_sum += val if val < 0: if overflow is not None: raise ValueError("Array supplied has more than one overflow digit") overflow = -val over_pos = (i, j) if coeff_sum > 0: raise ValueError("Sum of coefficients too small") if overflow is None or over_pos is None: raise ValueError("No dominating term supplied") return overflow, over_pos class Carry: """ Two-dimensional carry rule. This is an (integer) Numpy array which is greedily "applied" to another (integer) Numpy array to rid it of integers. Only "explicit" carries are allowed -- these are carries that are all negative except for a single positive term, such that the sum of all terms is non-negative. For example, the Laplacian kernel, as a carry rule is: `laplace = Carry([[0,1,0],[1,-4,1],[0,1,0]])` Interpreted as a carry rule, this spreads integers above 4 into their neighboring cells. """ def __init__(self, arr: np.ndarray | Sequence): if not isinstance(arr, np.ndarray): arr = np.array(arr) overflow, over_pos = _validate_carry(arr) self._arr: np.ndarray = arr self.over_pos: tuple[int, int] = over_pos self.overflow: int = overflow def _apply( self, base: np.ndarray, update_tuples: list[tuple[int, int]], zerowall: np.ndarray, ): """Apply the carry at the locations in `base` specified by `update_tuples`.""" pos_x, pos_y = self.over_pos for x, y in update_tuples: idx = base[(x, y)] dec = idx // self.overflow # - 1 wall = zerowall.copy() sh_x, sh_y = self._arr.shape x -= pos_x y -= pos_y wall[x : x + sh_x, y : y + sh_y] = dec * self._arr base += wall def apply(self, base: np.ndarray) -> np.ndarray: """Update an expansion `base` according to this carry.""" zerowall = np.zeros(base.shape, dtype=base.dtype) while update_tuples := list(zip(*np.where(base >= self.overflow))): self._apply(base, update_tuples, zerowall) return base