import matplotlib.pyplot as plt
import numpy as np
from sympy.abc import x, y
from sympy.plotting import plot_implicit

from .anim import animate


def _first_nonzero(arr: np.ndarray, axis: int):
    """Return the smallest index along ``axis`` holding a nonzero entry of ``arr``.

    Every line of ``arr`` parallel to ``axis`` is searched and the minimum
    index is returned. A line without any nonzero entry contributes the
    sentinel ``arr.shape[axis] + 1``, so an all-zero array yields that value.
    """
    mask = arr != 0
    # argmax on a boolean mask finds the first True; lines that contain no
    # True at all are replaced by a sentinel lying past the end of the axis.
    first = np.where(mask.any(axis=axis), mask.argmax(axis=axis), arr.shape[axis] + 1)
    return int(first.min())


def latex_polynumber(
    arr: np.ndarray, center: tuple[int, int] | None = None, show_zero: bool = False
) -> str:
    r"""Render the nonzero bounding box of a 2-D coefficient array as a LaTeX ``array``.

    Parameters
    ----------
    arr:
        2-D array of coefficients. Only the smallest rectangle containing
        every nonzero entry is rendered.
    center:
        ``(row, column)`` index into ``arr`` (absolute, not relative to the
        bounding box). A ``\hline`` is emitted before row ``center[0]`` and a
        vertical rule is placed to the left of column ``center[1]``. If
        omitted, the ``\hline`` goes before absolute row 0 (when that row is
        rendered) and the vertical rule before the first rendered column.
    show_zero:
        If true, zero entries inside the bounding box are printed as ``0``;
        otherwise they are left blank.

    Returns
    -------
    str
        LaTeX source of the form ``\begin{array}{...} ... \end{array}``.
    """
    first_row = _first_nonzero(arr, 0)
    first_col = _first_nonzero(arr, 1)
    last_row = arr.shape[0] - _first_nonzero(np.flip(arr, 0), 0) - 1
    # arr.shape[1], not len(arr[1]): indexing row 1 fails for one-row arrays.
    last_col = arr.shape[1] - _first_nonzero(np.flip(arr, 1), 1) - 1

    if center is None:
        center_top, center_left = 0, 0
    else:
        # The row is compared against absolute row indices in the loop below...
        center_top = center[0]
        # ...but the column is an offset into the rendered columns, because
        # it positions the "|" inside the array column specification.
        center_left = center[1] - first_col

    # The rendered column range is inclusive, hence the +1; without it the
    # column spec has one fewer "c" than there are cells per row.
    num_columns = last_col - first_col + 1
    # Keep the rule inside the spec even if the center lies outside the box.
    center_left = max(0, min(center_left, num_columns))
    column_layout = "c" * center_left + "|" + "c" * (num_columns - center_left)

    rows = []
    for i in range(first_row, last_row + 1):
        cells = " & ".join(
            str(arr[i, j] if arr[i, j] != 0 or show_zero else "")
            for j in range(first_col, last_col + 1)
        )
        # \hline emitted before a row draws the rule above that row.
        prefix = " \\hline " if i == center_top else ""
        rows.append(prefix + cells + " \\\\ ")

    return "\\begin{array}{" + column_layout + "}" + "".join(rows) + "\\end{array}"


def poly_from_array(array):
    """Build the SymPy polynomial ``sum(array[i][j] * x**i * y**j)``.

    Row index ``i`` is the power of ``x`` and column index ``j`` the power of
    ``y`` (``x`` and ``y`` are the symbols from ``sympy.abc``).
    """
    ret = 0
    for i, row in enumerate(array):
        for j, val in enumerate(row):
            ret += val * (x**i * y**j)
    return ret


def bindfig(fig):
    """Return a stand-in for ``plt.figure`` that always yields ``fig``.

    Keyword arguments are applied to ``fig`` through its ``set_<name>``
    setters (e.g. ``dpi=100`` calls ``fig.set_dpi(100)``). ``figsize`` is
    special-cased: it is split into ``set_figwidth``/``set_figheight`` and
    ignored when ``None``.
    """
    def figure(**kwargs):
        for name, value in kwargs.items():
            if name == "figsize":
                if value is not None:
                    fig.set_figwidth(value[0])
                    fig.set_figheight(value[1])
                continue
            # Setters are methods on the Figure class, not entries of the
            # instance __dict__, so they must be looked up with getattr.
            getattr(fig, "set_" + name)(value)
        return fig
    return figure


def anim_curves(next_func, dims=25, invalid=2, frames=None, interval=200):
    """Animate the implicit curves ``P(x, y) - frame = 0`` of an evolving coefficient array.

    A ``dims x dims`` int32 zero array is created once. At the start of every
    frame it is passed, together with ``invalid``, to ``next_func``
    (presumably mutating it in place; TODO confirm against callers). The
    array is then turned into a polynomial with ``poly_from_array`` and
    ``poly - fr = 0`` is drawn with ``plot_implicit`` into the current figure.

    Returns whatever ``animate(fig, frames, interval=interval)`` produces
    when applied to the per-frame callback.
    """
    coeffs = np.zeros((dims, dims), dtype=np.int32)
    fig = plt.gcf()
    # HACK: plot_implicit's matplotlib backend creates its own figure through
    # plt.figure() and offers no way to pass ours in. plt.figure is therefore
    # redirected to `fig`, but only for the duration of each plot_implicit
    # call, so the patch no longer leaks into unrelated code.
    fake_figure = bindfig(fig)

    @animate(fig, frames, interval=interval)
    def draw_frame(fr):
        next_func(coeffs, invalid)
        fig.clf()
        real_figure = plt.figure
        plt.figure = fake_figure
        try:
            plot_implicit(
                poly_from_array(coeffs) - fr, backend='matplotlib', adaptive=False
            )
        finally:
            plt.figure = real_figure
        plt.title(f"{fr+1}")
        print(fr)

    return draw_frame