# Source code for HyRex.array_with_padding

import jax.numpy as jnp
from jax import lax
import equinox as eqx


class array_with_padding(eqx.Module):
    """
    Array container with automatic padding management.

    Manages an array (assumed 1-D -- ``concat`` builds 1-D outputs) whose
    tail is filled with infinite padding values (``inf`` or ``-inf``), and
    records where the valid data ends so two padded arrays can be
    concatenated without the first one's padding landing in the middle of
    the result.

    Attributes:
    -----------
    arr : array
        Full array including padding elements.
    padding_size : int
        Number of infinite padding elements at the end of ``arr``; 0 when
        ``arr`` contains no infinite values. Computed with ``jnp`` ops, so
        it is a JAX integer scalar (traceable under ``jax.jit``).
    lastnum : int
        Index of the last valid (non-infinite) element, equal to
        ``arr.size - padding_size - 1``; -1 when ``arr`` is all padding.
    lastval : float
        ``arr[lastnum]``. For an all-padding array ``lastnum`` is -1, so
        this wraps around to the final (infinite) element.
    """

    arr: jnp.ndarray
    padding_size: int
    lastnum: int
    lastval: jnp.float64

    def __init__(self, arr):
        """
        Wrap ``arr`` and locate its trailing padding.

        Parameters:
        -----------
        arr : array
            Non-empty array of valid entries followed by zero or more
            infinite padding values. Only the FIRST infinite entry is
            located, so an infinity inside the valid data would be taken
            as the start of the padding.
        """
        self.arr = arr
        is_pad = jnp.isinf(arr)
        # argmax(is_pad) is 0 both when arr[0] is padding AND when there is
        # no padding at all. Treat "no infinite entry" as "padding starts
        # past the end" so an unpadded array gets padding_size == 0 instead
        # of being mistaken for an all-padding array.
        first_pad = jnp.where(jnp.any(is_pad), jnp.argmax(is_pad), arr.size)
        lastnum = first_pad - 1
        self.lastnum = lastnum
        self.lastval = arr[lastnum]
        self.padding_size = arr.size - first_pad

    def __call__(self):
        """
        Return the full array including padding.

        Returns:
        --------
        array
            ``self.arr`` unchanged, padding elements included.
        """
        return self.arr

    def concat(self, other_arr):
        """
        Concatenate with another padded array.

        Writes ``other_arr``'s full array (including its own padding)
        starting at this array's first padding slot. This array's padding
        is pushed to the end, so the result's padding is the sum of both
        paddings.

        Parameters:
        -----------
        other_arr : array_with_padding
            Second array to concatenate after this one.

        Returns:
        --------
        array_with_padding
            New padded array of size ``self.arr.size + other_arr.arr.size``.
            Its dtype is the common inexact dtype of both inputs.

        Raises:
        -------
        TypeError
            If ``other_arr`` is not an ``array_with_padding``.
        """
        if not isinstance(other_arr, array_with_padding):
            raise TypeError(
                "Can only concatenate with another array_with_padding instance."
            )
        x = self.arr
        y = other_arr.arr
        # One shared inexact dtype for the output. Float inputs keep their
        # precision (no forced float32/float64), integer inputs are promoted
        # so the inf fill is representable, and x, y and z all agree --
        # lax.dynamic_update_slice rejects mismatched operand/update dtypes.
        dtype = jnp.result_type(x, y, 1.0)
        # x.size and y.size are static Python ints (shapes are never traced),
        # so the output shape is fixed at trace time.
        z = jnp.full(x.size + y.size, jnp.inf, dtype=dtype)
        z = z.at[: x.size].set(x.astype(dtype))
        # The start offset depends on padding_size and may be a tracer, hence
        # dynamic_update_slice rather than static slicing.
        start = x.size - self.padding_size
        z = lax.dynamic_update_slice(z, y.astype(dtype), (start,))
        return array_with_padding(z)