from functools import partial
from typing import Callable
import jax
import jax.numpy as jnp
import lox
from flax import linen as nn
from flax import struct
from flax.typing import Dtype
from jax.nn.initializers import lecun_normal
from memorax.utils.typing import Array, Carry
from .rnn import RNNCellBase
def _initialize_nu_log(key, shape, r_min=0.0, r_max=1.0):
u = jax.random.uniform(key, shape=shape)
return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))
def _initialize_theta_log(key, shape, max_phase=6.28):
u = jax.random.uniform(key, shape=shape)
return jnp.log(max_phase * u)
@struct.dataclass
class RTUConfig:
    """Hyperparameters consumed by :class:`RTUCell`."""

    # Size of the last input axis; second dimension of the B_real / B_imag
    # parameter matrices.
    features: int
    # Number of recurrent units; length of nu_log / theta_log and of each
    # half (real, imaginary) of the carry.
    hidden_dim: int
    # Radius bounds forwarded to the nu_log initializer.
    r_min: float = 0.0
    r_max: float = 1.0
    # Upper phase bound forwarded to the theta_log initializer.
    max_phase: float = 6.28
    # Added to sqrt(1 - r**2) when computing the input normalisation.
    eps: float = 1e-8
    # NOTE(review): dtype and param_dtype are declared here but are not read
    # by the cell code visible in this file - confirm whether that is intended.
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32
    # Nonlinearity applied to the real and imaginary pre-activations. Kept out
    # of the pytree because a callable is not array data.
    activation_fn: Callable = struct.field(
        default=jnp.tanh,
        pytree_node=False,
    )
@struct.dataclass
class RTUCarry:
    """Recurrent state of :class:`RTUCell`, split into real and imaginary parts."""

    # Real component of the hidden state; last axis has size hidden_dim.
    real: Array
    # Imaginary component of the hidden state; same shape as `real`.
    imaginary: Array
class RTUCell(RNNCellBase):
    """Recurrent cell with a per-unit rotation recurrence on a (real, imaginary) state.

    Every hidden unit has a radius ``r = exp(-exp(nu_log))`` and a phase
    ``theta = exp(theta_log)``. With ``g = r * cos(theta)``,
    ``phi = r * sin(theta)`` and ``norm = sqrt(1 - r**2) + eps`` one step is::

        pre_real = g * real - phi * imaginary + norm * (inputs @ B_real.T)
        pre_imag = g * imaginary + phi * real + norm * (inputs @ B_imag.T)
        new_carry = (f(pre_real), f(pre_imag))

    where ``f`` is ``config.activation_fn``. The cell output is the
    concatenation of the new real and imaginary parts along the last axis.

    Besides the ordinary forward step the cell exposes helpers that propagate
    per-parameter sensitivities of the state forward in time
    (``initialize_sensitivity``, ``local_jacobian``) and re-attach them to the
    autodiff graph (``compute_phantom``, ``inject_phantom``).
    """

    config: RTUConfig

    @property
    def num_feature_axes(self) -> int:
        """Number of feature axes of the input; always 1 for this cell."""
        return 1

    def setup(self):
        """Create the four parameters: ``nu_log``, ``theta_log``, ``B_real``, ``B_imag``."""
        cfg = self.config
        self.nu_log = self.param(
            "nu_log",
            partial(_initialize_nu_log, r_min=cfg.r_min, r_max=cfg.r_max),
            (cfg.hidden_dim,),
        )
        self.theta_log = self.param(
            "theta_log",
            partial(_initialize_theta_log, max_phase=cfg.max_phase),
            (cfg.hidden_dim,),
        )
        # The input matrices are stored as (hidden_dim, features) and applied
        # as `inputs @ B.T`, so the fan-in is the LAST axis. lecun_normal's
        # default (in_axis=-2) would scale the variance by hidden_dim instead
        # of by the number of input features.
        input_init = lecun_normal(in_axis=-1, out_axis=-2)
        self.B_real = self.param(
            "B_real",
            input_init,
            (cfg.hidden_dim, cfg.features),
        )
        self.B_imag = self.param(
            "B_imag",
            input_init,
            (cfg.hidden_dim, cfg.features),
        )

    def _g_phi_norm(self) -> tuple[Array, Array, Array, Array]:
        """Return ``(g, phi, norm, r)`` derived from ``nu_log`` and ``theta_log``.

        ``g`` and ``phi`` are the cosine and sine components of the per-unit
        rotation scaled by the radius ``r``; ``norm`` is the input scaling
        ``sqrt(1 - r**2) + eps``.
        """
        r = jnp.exp(-jnp.exp(self.nu_log))
        theta = jnp.exp(self.theta_log)
        g = r * jnp.cos(theta)
        phi = r * jnp.sin(theta)
        norm = jnp.sqrt(1 - r**2) + self.config.eps
        return g, phi, norm, r

    def _pre_activations(
        self,
        carry: RTUCarry,
        inputs: Array,
        g: Array,
        phi: Array,
        norm: Array,
    ) -> tuple[Array, Array, Array, Array]:
        """Return ``(u_real, u_imaginary, pre_real, pre_imaginary)`` for one step.

        ``u_*`` are the input projections ``inputs @ B_*.T``; ``pre_*`` are the
        rotated carry plus the normalised projections, before the activation.
        """
        u_real = inputs @ self.B_real.T
        u_imaginary = inputs @ self.B_imag.T
        pre_real = g * carry.real - phi * carry.imaginary + norm * u_real
        pre_imaginary = g * carry.imaginary + phi * carry.real + norm * u_imaginary
        return u_real, u_imaginary, pre_real, pre_imaginary

    @nn.compact
    def __call__(self, carry: RTUCarry, inputs: Array) -> tuple[RTUCarry, Array]:
        """Advance the cell by one step.

        Args:
            carry: Previous state (real and imaginary parts).
            inputs: Input array whose last axis has size ``config.features``.

        Returns:
            ``(new_carry, output)`` where ``output`` concatenates the new real
            and imaginary parts along the last axis.
        """
        g, phi, norm, r = self._g_phi_norm()

        # Diagnostics only; these values do not feed the state update.
        theta = jnp.exp(self.theta_log)
        tau = jnp.exp(-self.nu_log)
        lox.log({
            "rtu/r": r.mean(),
            "rtu/r_min": r.min(),
            "rtu/r_max": r.max(),
            "rtu/tau": tau.mean(),
            "rtu/theta": theta.mean(),
            "rtu/theta_std": theta.std(),
        })

        _, _, pre_real, pre_imaginary = self._pre_activations(
            carry, inputs, g, phi, norm
        )
        f = self.config.activation_fn
        new_carry = RTUCarry(real=f(pre_real), imaginary=f(pre_imaginary))
        output = jnp.concatenate([new_carry.real, new_carry.imaginary], axis=-1)
        return new_carry, output

    @nn.nowrap
    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> RTUCarry:
        """Return an all-zero carry.

        Args:
            key: Unused; accepted for interface compatibility.
            input_shape: Shape of one input; every axis but the last is
                treated as a batch axis.

        Returns:
            Carry whose parts have shape ``(*batch_dims, hidden_dim)``.
        """
        *batch_dims, _ = input_shape
        zeros = jnp.zeros((*batch_dims, self.config.hidden_dim))
        return RTUCarry(real=zeros, imaginary=zeros)

    def compute_phantom(self, sensitivity: dict[str, Array]) -> RTUCarry:
        """Build a zero-valued carry whose parameter gradient is ``sensitivity``.

        For each parameter, ``param - stop_gradient(param)`` is zero in value
        but differentiable with respect to ``param``; multiplying it by the
        sensitivity and summing over the parameter's extra axes yields a term
        that contributes nothing to the forward value while carrying the
        sensitivity as its gradient.

        Args:
            sensitivity: Maps parameter name to an array laid out as
                ``(batch, 2, hidden, ...)`` - index 0 of axis 1 is the real
                part, index 1 the imaginary part (as implied by the indexing
                below; a single leading batch axis is assumed).

        Returns:
            Carry holding the accumulated real and imaginary phantom terms.
        """
        params = self.variables["params"]
        real_phantom = 0
        imaginary_phantom = 0
        for name, S in sensitivity.items():
            param = params[name]
            diff = param - jax.lax.stop_gradient(param)
            # Axes from 3 onward are parameter axes beyond `hidden` (e.g. the
            # feature axis of B_real / B_imag); reduce them away.
            contribution = jnp.sum(S * diff, axis=tuple(range(3, S.ndim)))
            real_phantom = real_phantom + contribution[:, 0]
            imaginary_phantom = imaginary_phantom + contribution[:, 1]
        return RTUCarry(real=real_phantom, imaginary=imaginary_phantom)

    def inject_phantom(self, carry: RTUCarry, phantom: RTUCarry) -> RTUCarry:
        """Return ``carry`` with its gradient cut and ``phantom`` added to each part."""
        return RTUCarry(
            real=jax.lax.stop_gradient(carry.real) + phantom.real,
            imaginary=jax.lax.stop_gradient(carry.imaginary) + phantom.imaginary,
        )

    def local_jacobian(
        self,
        carry: RTUCarry,
        inputs: Array,
        sensitivity: dict[str, Array],
        **kwargs,
    ) -> tuple[RTUCarry, Array, dict[str, Array]]:
        """Advance one step and propagate the parameter sensitivities.

        For every parameter present in ``sensitivity`` the update is
        ``S_next = f'(pre) * (A @ S + J)`` where ``A`` is the per-unit 2x2
        rotation ``[[g, -phi], [phi, g]]``, ``J`` is the direct derivative of
        the pre-activations with respect to that parameter and ``f'`` is the
        elementwise derivative of the activation.

        Args:
            carry: Previous state; parts are expected as ``(batch, hidden)``.
            inputs: Inputs, expected as ``(batch, features)`` by the einsum
                that builds the B-matrix Jacobians.
            sensitivity: Maps parameter name to its current sensitivity,
                laid out as ``(batch, 2, hidden, ...)``.
            **kwargs: Ignored.

        Returns:
            ``(new_carry, output, next_sensitivity)``; the first two match
            what ``__call__`` computes for the same arguments.
        """
        g, phi, norm, r = self._g_phi_norm()
        f = self.config.activation_fn

        u_real, u_imaginary, pre_real, pre_imaginary = self._pre_activations(
            carry, inputs, g, phi, norm
        )

        # Elementwise activation derivative, obtained by differentiating the sum.
        d_real = jax.grad(lambda x: f(x).sum())(pre_real)
        d_imaginary = jax.grad(lambda x: f(x).sum())(pre_imaginary)

        new_carry = RTUCarry(real=f(pre_real), imaginary=f(pre_imaginary))
        output = jnp.concatenate([new_carry.real, new_carry.imaginary], axis=-1)

        # A has shape (2, 2, hidden); d has shape (batch, 2, hidden).
        A = jnp.stack([jnp.stack([g, -phi]), jnp.stack([phi, g])])
        d = jnp.stack([d_real, d_imaginary], axis=1)

        # Derivatives with respect to nu_log: dr/dnu_log = -exp(nu_log) * r.
        exp_nu = jnp.exp(self.nu_log)
        dg_dnu = -exp_nu * g
        dphi_dnu = -exp_nu * phi
        # The fixed 1e-12 guards the division when r rounds to 1.
        dnorm_dnu = exp_nu * r**2 / (jnp.sqrt(1 - r**2) + 1e-12)

        # Derivatives with respect to theta_log: dtheta/dtheta_log = theta.
        theta = jnp.exp(self.theta_log)
        dg_dtheta = -phi * theta
        dphi_dtheta = g * theta

        # d(pre)/dB[h, f] = norm[h] * inputs[b, f]; B_real only affects the
        # real pre-activation and B_imag only the imaginary one.
        Bu = jnp.einsum('h,bf->bhf', norm, inputs)
        zeros_bhf = jnp.zeros_like(Bu)

        jacobians = {
            "nu_log": jnp.stack([
                dg_dnu * carry.real - dphi_dnu * carry.imaginary + dnorm_dnu * u_real,
                dg_dnu * carry.imaginary + dphi_dnu * carry.real + dnorm_dnu * u_imaginary,
            ], axis=1),
            "theta_log": jnp.stack([
                dg_dtheta * carry.real - dphi_dtheta * carry.imaginary,
                dg_dtheta * carry.imaginary + dphi_dtheta * carry.real,
            ], axis=1),
            "B_real": jnp.stack([Bu, zeros_bhf], axis=1),
            "B_imag": jnp.stack([zeros_bhf, Bu], axis=1),
        }

        next_sensitivity = {}
        for name in sensitivity:
            S = sensitivity[name]
            J = jacobians[name]
            # Mix the real/imaginary sensitivities through the rotation.
            rotated = jnp.einsum('ijh,bjh...->bih...', A, S)
            next_sensitivity[name] = jnp.einsum('bih,bih...->bih...', d, rotated + J)

        return new_carry, output, next_sensitivity

    def initialize_sensitivity(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> dict[str, Array] | None:
        """Return all-zero sensitivities for the four parameters.

        Args:
            key: Unused; accepted for interface compatibility.
            input_shape: Shape of one input; every axis but the last is
                treated as a batch axis.

        Returns:
            Dict keyed by parameter name. ``nu_log`` and ``theta_log`` entries
            have shape ``(*batch_dims, 2, hidden_dim)``; ``B_real`` and
            ``B_imag`` entries have shape
            ``(*batch_dims, 2, hidden_dim, features)``.
        """
        *batch_dims, _ = input_shape
        H = self.config.hidden_dim
        F = self.config.features
        return {
            "nu_log": jnp.zeros((*batch_dims, 2, H)),
            "theta_log": jnp.zeros((*batch_dims, 2, H)),
            "B_real": jnp.zeros((*batch_dims, 2, H, F)),
            "B_imag": jnp.zeros((*batch_dims, 2, H, F)),
        }