Module lib.llama.embedding

Expand source code
import math

from jax import Array
import jax.random as rand

from .ModelConfig import ModelConfig

def check_embedding(params: Array, *, model_config: ModelConfig) -> None:
    """Validate an embedding table against the model configuration.

    Args:
        params: The embedding table to check.
        model_config: Configuration providing ``vocab_size`` and ``d_model``.

    Raises:
        AssertionError: If ``params`` is not a ``jax.Array`` or its shape is not
            ``(model_config.vocab_size, model_config.d_model)``.
    """
    # Explicit raises instead of bare ``assert`` keep this check active under
    # ``python -O``; AssertionError is kept so callers observe the same type.
    if not isinstance(params, Array):
        raise AssertionError(
            f'embedding params must be a jax.Array, got {type(params).__name__}'
        )
    expected_shape = (model_config.vocab_size, model_config.d_model)
    if params.shape != expected_shape:
        raise AssertionError(
            f'embedding params have shape {params.shape}, expected {expected_shape}'
        )

def init_embedding(*, key: Array, model_config: ModelConfig) -> Array:
    """Sample a fresh embedding table.

    Each entry is drawn from a standard normal distribution truncated to the
    symmetric interval ``[-1/sqrt(d_model), 1/sqrt(d_model)]``.

    Args:
        key: JAX PRNG key used for sampling.
        model_config: Configuration providing ``vocab_size`` and ``d_model``.

    Returns:
        An array of shape ``(model_config.vocab_size, model_config.d_model)``.
    """
    table_shape = (model_config.vocab_size, model_config.d_model)
    bound = 1. / math.sqrt(model_config.d_model)
    return rand.truncated_normal(key, lower=-bound, upper=bound, shape=table_shape)

def forward_embedding(params: Array, x: Array) -> Array:
    """Gather embedding rows selected by ``x``.

    Args:
        params: Embedding table, indexed along its first axis.
        x: Index array, presumably integer token ids (confirm against callers).

    Returns:
        ``params[x]``; for an integer index array the result has shape
        ``x.shape + params.shape[1:]``.
    """
    # NOTE: JAX clamps out-of-bounds gather indices instead of raising, so
    # invalid ids silently map to the first/last row.
    rows = params[x]
    return rows

Functions

def check_embedding(params: jax.Array, *, model_config: ModelConfig) -> None
Expand source code
def check_embedding(params: Array, *, model_config: ModelConfig) -> None:
    """Validate an embedding table against the model configuration.

    Args:
        params: The embedding table to check.
        model_config: Configuration providing ``vocab_size`` and ``d_model``.

    Raises:
        AssertionError: If ``params`` is not a ``jax.Array`` or its shape is not
            ``(model_config.vocab_size, model_config.d_model)``.
    """
    # Explicit raises instead of bare ``assert`` keep this check active under
    # ``python -O``; AssertionError is kept so callers observe the same type.
    if not isinstance(params, Array):
        raise AssertionError(
            f'embedding params must be a jax.Array, got {type(params).__name__}'
        )
    expected_shape = (model_config.vocab_size, model_config.d_model)
    if params.shape != expected_shape:
        raise AssertionError(
            f'embedding params have shape {params.shape}, expected {expected_shape}'
        )
def init_embedding(*, key: jax.Array, model_config: ModelConfig) -> jax.Array
Expand source code
def init_embedding(*, key: Array, model_config: ModelConfig) -> Array:
    """Sample a fresh embedding table.

    Each entry is drawn from a standard normal distribution truncated to the
    symmetric interval ``[-1/sqrt(d_model), 1/sqrt(d_model)]``.

    Args:
        key: JAX PRNG key used for sampling.
        model_config: Configuration providing ``vocab_size`` and ``d_model``.

    Returns:
        An array of shape ``(model_config.vocab_size, model_config.d_model)``.
    """
    table_shape = (model_config.vocab_size, model_config.d_model)
    bound = 1. / math.sqrt(model_config.d_model)
    return rand.truncated_normal(key, lower=-bound, upper=bound, shape=table_shape)
def forward_embedding(params: jax.Array, x: jax.Array) -> jax.Array
Expand source code
def forward_embedding(params: Array, x: Array) -> Array:
    """Gather embedding rows selected by ``x``.

    Args:
        params: Embedding table, indexed along its first axis.
        x: Index array, presumably integer token ids (confirm against callers).

    Returns:
        ``params[x]``; for an integer index array the result has shape
        ``x.shape + params.shape[1:]``.
    """
    # NOTE: JAX clamps out-of-bounds gather indices instead of raising, so
    # invalid ids silently map to the first/last row.
    rows = params[x]
    return rows