Module lib.llama.rms_norm
Expand source code
from functools import partial
import jax
from jax import Array
import jax.numpy as jnp
from .ModelConfig import ModelConfig
def check_rms_norm(params: Array, *, model_config: ModelConfig) -> None:
    """Validate RMS-norm parameters against ``model_config``.

    Raises ``AssertionError`` unless ``params`` is a JAX ``Array`` whose shape
    is exactly ``(model_config.d_model,)``. The checks are plain ``assert``
    statements, so they are skipped when Python runs with ``-O``.
    """
    # The type check comes first, so a non-Array input fails before
    # ``model_config`` is consulted at all.
    assert isinstance(params, Array)
    expected_shape = (model_config.d_model,)
    assert params.shape == expected_shape
def init_rms_norm(*, model_config: ModelConfig) -> Array:
    """Create the initial RMS-norm scale vector.

    Returns a vector of ones of length ``model_config.d_model``. The forward
    pass multiplies the normalized input by this vector, so a freshly
    initialized layer applies no extra scaling.
    """
    d_model = model_config.d_model
    return jnp.ones(shape=(d_model,))
# Taken from https://github.com/ztjhz/t5-jax/blob/main/model/layer_norm.py#L23
@partial(jax.jit, static_argnames=('model_config',))
def forward_rms_norm(params: Array, x: Array, *, model_config: ModelConfig) -> Array:
    """Apply RMS normalization over the last axis of ``x`` and scale by ``params``.

    Computes ``x / sqrt(mean(x ** 2, axis=-1) + eps) * params`` with
    ``eps = model_config.rms_norm_eps``.

    The mean of squares and the division are computed in at least float32.
    Squaring low-precision activations in their own dtype is lossy, and in
    float16 (max ~65504) it overflows to ``inf`` for ``|x| > ~256``, which
    would silently zero out the whole row. The normalized value is then cast
    back to ``x``'s floating dtype before scaling. For float32 inputs this is
    the same computation as the plain formula.

    Args:
        params: Per-feature scale, broadcast against the last axis of ``x``
            (``init_rms_norm`` creates it with shape ``(d_model,)``).
        x: Input array, normalized along its last axis.
        model_config: Static (hashed by ``jax.jit``) configuration; only
            ``rms_norm_eps`` is read here.

    Returns:
        The normalized, scaled array. Its dtype follows JAX type promotion of
        the normalized value with ``params``.
    """
    # Widen to >= float32 for the statistics; float64 stays float64.
    compute_dtype = jnp.promote_types(x.dtype, jnp.float32)
    x_wide = x.astype(compute_dtype)
    x_rms = jnp.sqrt((x_wide * x_wide).mean(axis=-1, keepdims=True) + model_config.rms_norm_eps)
    normed = x_wide / x_rms
    # Restore the caller's floating dtype, as the reference LLaMA norm does.
    # Non-float inputs keep the float result, as the original formula did.
    if jnp.issubdtype(x.dtype, jnp.floating):
        normed = normed.astype(x.dtype)
    y = normed * params
    return y
Functions
def check_rms_norm(params: jax.Array, *, model_config: ModelConfig) -> None
Assert that `params` is a JAX `Array` of shape `(model_config.d_model,)`.
Expand source code
def check_rms_norm(params: Array, *, model_config: ModelConfig) -> None:
    """Validate RMS-norm parameters against ``model_config``.

    Raises ``AssertionError`` unless ``params`` is a JAX ``Array`` whose shape
    is exactly ``(model_config.d_model,)``. The checks are plain ``assert``
    statements, so they are skipped when Python runs with ``-O``.
    """
    # The type check comes first, so a non-Array input fails before
    # ``model_config`` is consulted at all.
    assert isinstance(params, Array)
    expected_shape = (model_config.d_model,)
    assert params.shape == expected_shape
def init_rms_norm(*, model_config: ModelConfig) -> jax.Array
Return the initial RMS-norm scale: a vector of ones of shape `(model_config.d_model,)`.
Expand source code
def init_rms_norm(*, model_config: ModelConfig) -> Array:
    """Create the initial RMS-norm scale vector.

    Returns a vector of ones of length ``model_config.d_model``. The forward
    pass multiplies the normalized input by this vector, so a freshly
    initialized layer applies no extra scaling.
    """
    d_model = model_config.d_model
    return jnp.ones(shape=(d_model,))
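def forward_rms_norm(params: jax.Array, x: jax.Array, *, model_config: ModelConfig) -> jax.Array
Normalize `x` by its root mean square over the last axis, with `model_config.rms_norm_eps` added inside the square root, then multiply by `params`.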
Expand source code
@partial(jax.jit, static_argnames=('model_config',))
def forward_rms_norm(params: Array, x: Array, *, model_config: ModelConfig) -> Array:
    """Apply RMS normalization over the last axis of ``x`` and scale by ``params``.

    Computes ``x / sqrt(mean(x ** 2, axis=-1) + eps) * params`` with
    ``eps = model_config.rms_norm_eps``.

    The mean of squares and the division are computed in at least float32.
    Squaring low-precision activations in their own dtype is lossy, and in
    float16 (max ~65504) it overflows to ``inf`` for ``|x| > ~256``, which
    would silently zero out the whole row. The normalized value is then cast
    back to ``x``'s floating dtype before scaling. For float32 inputs this is
    the same computation as the plain formula.

    Args:
        params: Per-feature scale, broadcast against the last axis of ``x``
            (``init_rms_norm`` creates it with shape ``(d_model,)``).
        x: Input array, normalized along its last axis.
        model_config: Static (hashed by ``jax.jit``) configuration; only
            ``rms_norm_eps`` is read here.

    Returns:
        The normalized, scaled array. Its dtype follows JAX type promotion of
        the normalized value with ``params``.
    """
    # Widen to >= float32 for the statistics; float64 stays float64.
    compute_dtype = jnp.promote_types(x.dtype, jnp.float32)
    x_wide = x.astype(compute_dtype)
    x_rms = jnp.sqrt((x_wide * x_wide).mean(axis=-1, keepdims=True) + model_config.rms_norm_eps)
    normed = x_wide / x_rms
    # Restore the caller's floating dtype, as the reference LLaMA norm does.
    # Non-float inputs keep the float result, as the original formula did.
    if jnp.issubdtype(x.dtype, jnp.floating):
        normed = normed.astype(x.dtype)
    y = normed * params
    return y