82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
from typing import Sequence
|
|
import numpy as np
|
|
|
|
|
|
def _validate_carry(arr: np.ndarray) -> tuple[int, tuple[int, int]]:
    """
    Ensure that `arr` is interpretable as an explicit carry rule.

    An explicit carry is a two-dimensional array with exactly one negative entry (the
    "overflow digit") whose entries sum to a non-negative value.

    See `Carry` for more information.

    Parameters
    ----------
    arr
        Candidate carry rule.

    Returns
    -------
    overflow
        Magnitude of the single negative entry, as a plain Python scalar.
    over_pos
        ``(row, column)`` index of that entry within `arr`.

    Raises
    ------
    ValueError
        If `arr` is not two-dimensional, has more than one negative entry, has entries that
        sum to a negative value, or has no negative entry at all.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError(f"Carry must be a two-dimensional array, got {arr.ndim} dimension(s)")

    # Locations of every negative entry; a valid carry has exactly one.
    negatives = np.argwhere(arr < 0)

    # The checks run in the same order as the error conditions were historically reported.
    if len(negatives) > 1:
        raise ValueError("Array supplied has more than one overflow digit")
    if arr.sum() < 0:
        raise ValueError("Sum of coefficients too small")
    if len(negatives) == 0:
        raise ValueError("No dominating term supplied")

    row, col = (int(k) for k in negatives[0])
    # Convert to a Python scalar *before* negating so the result matches the annotated
    # return type and the negation cannot wrap around in a narrow integer dtype.
    overflow = -arr[row, col].item()
    return overflow, (row, col)
|
|
|
|
|
|
class Carry:
|
|
"""
|
|
Two-dimensional carry rule.
|
|
|
|
This is an (integer) Numpy array which is greedily "applied" to another (integer) Numpy array to
|
|
rid it of integers.
|
|
Only "explicit" carries are allowed -- these are carries that are all negative except for a
|
|
single positive term, such that the sum of all terms is non-negative.
|
|
|
|
For example, the Laplacian kernel, as a carry rule is:
|
|
`laplace = Carry([[0,1,0],[1,-4,1],[0,1,0]])`
|
|
|
|
Interpreted as a carry rule, this spreads integers above 4 into their neighboring cells.
|
|
"""
|
|
|
|
def __init__(self, arr: np.ndarray | Sequence):
|
|
if not isinstance(arr, np.ndarray):
|
|
arr = np.array(arr)
|
|
overflow, over_pos = _validate_carry(arr)
|
|
|
|
self._arr: np.ndarray = arr
|
|
self.over_pos: tuple[int, int] = over_pos
|
|
self.overflow: int = overflow
|
|
|
|
def _apply(
|
|
self,
|
|
base: np.ndarray,
|
|
update_tuples: list[tuple[int, int]],
|
|
zerowall: np.ndarray,
|
|
):
|
|
"""Apply the carry at the locations in `base` specified by `update_tuples`."""
|
|
pos_x, pos_y = self.over_pos
|
|
for x, y in update_tuples:
|
|
idx = base[(x, y)]
|
|
dec = idx // self.overflow # - 1
|
|
|
|
wall = zerowall.copy()
|
|
sh_x, sh_y = self._arr.shape
|
|
x -= pos_x
|
|
y -= pos_y
|
|
|
|
wall[x : x + sh_x, y : y + sh_y] = dec * self._arr
|
|
base += wall
|
|
|
|
def apply(self, base: np.ndarray) -> np.ndarray:
|
|
"""Update an expansion `base` according to this carry."""
|
|
zerowall = np.zeros(base.shape, dtype=base.dtype)
|
|
while update_tuples := list(zip(*np.where(base >= self.overflow))):
|
|
self._apply(base, update_tuples, zerowall)
|
|
return base
|