Module lib.llama.llama_model
from functools import partial
from typing import Any, NamedTuple

import jax
from jax import Array
import jax.numpy as jnp
import jax.random as rand

from .ModelConfig import ModelConfig
from .decoder import Decoder, check_decoder, forward_decoder, init_decoder
from .embedding import check_embedding, forward_embedding, init_embedding
from .kv_cache import KVCache
from .rms_norm import check_rms_norm, forward_rms_norm, init_rms_norm
from .rotary_embedding import RotaryValues

class LlamaModel(NamedTuple):
    embedding: Any  # Array
    decoder: Decoder
    norm: Any  # Array

def check_llama_model(params: LlamaModel, *, model_config: ModelConfig) -> None:
    assert isinstance(params.embedding, Array)
    assert isinstance(params.decoder, Decoder)
    assert isinstance(params.norm, Array)
    check_embedding(params.embedding, model_config=model_config)
    check_decoder(params.decoder, model_config=model_config)
    check_rms_norm(params.norm, model_config=model_config)

def init_llama_model(*, key: Array, model_config: ModelConfig) -> LlamaModel:
    key0, key1 = rand.split(key)
    embedding = init_embedding(key=key0, model_config=model_config)
    decoder = init_decoder(key=key1, model_config=model_config)
    norm = init_rms_norm(model_config=model_config)
    return LlamaModel(embedding, decoder, norm)
@partial(jax.jit, static_argnames=('model_config',))
def forward_llama_model(params: LlamaModel, seq: Array, qk_mask: Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
    assert isinstance(seq, Array)
    assert isinstance(qk_mask, Array)
    assert seq.dtype == jnp.uint16  # token ids
    assert qk_mask.dtype == jnp.bool_  # attention mask
    assert model_config.d_k % 2 == 0  # rotary embeddings rotate pairs of head dimensions
    assert key is None or model_config.dropout_rate is not None  # a key is only meaningful with dropout

    # Embed tokens, run the decoder stack (updating the KV cache), then apply the final RMSNorm.
    seq = forward_embedding(params.embedding, seq)
    seq, kv_cache = forward_decoder(params.decoder, seq, qk_mask, rotary_values=rotary_values, kv_cache=kv_cache, key=key, model_config=model_config)
    seq = forward_rms_norm(params.norm, seq, model_config=model_config)
    return seq, kv_cache
Functions
def check_llama_model(params: LlamaModel, *, model_config: ModelConfig) -> None

Asserts that every field of params has the expected type, then delegates per-submodule shape checks to check_embedding, check_decoder and check_rms_norm.

def check_llama_model(params: LlamaModel, *, model_config: ModelConfig) -> None:
    assert isinstance(params.embedding, Array)
    assert isinstance(params.decoder, Decoder)
    assert isinstance(params.norm, Array)
    check_embedding(params.embedding, model_config=model_config)
    check_decoder(params.decoder, model_config=model_config)
    check_rms_norm(params.norm, model_config=model_config)
def init_llama_model(*, key: jax.Array, model_config: ModelConfig) -> LlamaModel

Splits the PRNG key, initializes the embedding table and the decoder stack from the two halves, initializes the final RMSNorm weights, and returns them as a LlamaModel.

def init_llama_model(*, key: Array, model_config: ModelConfig) -> LlamaModel:
    key0, key1 = rand.split(key)
    embedding = init_embedding(key=key0, model_config=model_config)
    decoder = init_decoder(key=key1, model_config=model_config)
    norm = init_rms_norm(model_config=model_config)
    return LlamaModel(embedding, decoder, norm)
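A minimal initialization sketch. The fields of ModelConfig are not documented in this module, so its construction is left as a placeholder; the key handling follows the source above.

import jax.random as rand

from lib.llama.ModelConfig import ModelConfig
from lib.llama.llama_model import init_llama_model, check_llama_model

model_config: ModelConfig = ...  # built elsewhere; its fields are not shown in this module

key = rand.PRNGKey(0)
params = init_llama_model(key=key, model_config=model_config)

# Optional sanity check on types and per-submodule shapes.
check_llama_model(params, model_config=model_config)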
def forward_llama_model(params: LlamaModel, seq: jax.Array, qk_mask: jax.Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None = None, key: jax.Array | None = None, model_config: ModelConfig) -> tuple[jax.Array, KVCache | None]

Embeds the uint16 token ids in seq, runs the decoder stack with the given rotary values, attention mask and optional KV cache, applies the final RMSNorm, and returns the hidden states together with the updated KV cache. The function is jit-compiled with model_config as a static argument; a PRNG key is only accepted when dropout_rate is configured.

@partial(jax.jit, static_argnames=('model_config',))
def forward_llama_model(params: LlamaModel, seq: Array, qk_mask: Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
    assert isinstance(seq, Array)
    assert isinstance(qk_mask, Array)
    assert seq.dtype == jnp.uint16  # token ids
    assert qk_mask.dtype == jnp.bool_  # attention mask
    assert model_config.d_k % 2 == 0  # rotary embeddings rotate pairs of head dimensions
    assert key is None or model_config.dropout_rate is not None  # a key is only meaningful with dropout

    # Embed tokens, run the decoder stack (updating the KV cache), then apply the final RMSNorm.
    seq = forward_embedding(params.embedding, seq)
    seq, kv_cache = forward_decoder(params.decoder, seq, qk_mask, rotary_values=rotary_values, kv_cache=kv_cache, key=key, model_config=model_config)
    seq = forward_rms_norm(params.norm, seq, model_config=model_config)
    return seq, kv_cache
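A hedged sketch of a single forward pass, reusing params and model_config from the initialization example above. The exact broadcast shape expected for qk_mask and the way RotaryValues is built are defined in the decoder and rotary_embedding modules, not here, so both are illustrative assumptions.

import jax.numpy as jnp

from lib.llama.llama_model import forward_llama_model
from lib.llama.rotary_embedding import RotaryValues

batch_size, seq_len = 1, 8

# Token ids must be uint16 and the mask boolean, per the asserts above.
seq = jnp.zeros((batch_size, seq_len), dtype=jnp.uint16)

# Assumed causal mask: position i may attend to positions <= i.
qk_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))

# Construction of RotaryValues is not shown in this module; placeholder only.
rotary_values: RotaryValues = ...

hidden_states, kv_cache = forward_llama_model(
    params, seq, qk_mask,
    rotary_values=rotary_values,
    kv_cache=None,   # pass a KVCache to decode incrementally
    key=None,        # only needed when dropout_rate is set
    model_config=model_config,
)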
Classes
class LlamaModel (embedding: Any, decoder: Decoder, norm: Any)

Parameter container for the full model: the token-embedding table, the decoder stack and the final RMSNorm weights.

class LlamaModel(NamedTuple):
    embedding: Any  # Array
    decoder: Decoder
    norm: Any  # Array
Ancestors
- builtins.tuple
Instance variables
var embedding : Any
Alias for field number 0
var decoder : Decoder
Alias for field number 1
var norm : Any
Alias for field number 2
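Because LlamaModel is a NamedTuple whose fields hold arrays and nested parameter containers, JAX treats it as a pytree (it is passed as a regular, non-static argument to the jitted forward function), so the standard tree utilities work on the full parameter set. A small sketch:

import jax
import jax.numpy as jnp

# Total parameter count across embedding, decoder and norm.
n_params = sum(x.size for x in jax.tree_util.tree_leaves(params))

# Cast every weight to bfloat16 without changing the structure.
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)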