Module lib.llama.llama

from functools import partial
import math
from typing import Any, NamedTuple

import jax
from jax import Array
import jax.random as rand

from .ModelConfig import ModelConfig
from .kv_cache import KVCache
from .llama_model import LlamaModel, check_llama_model, forward_llama_model, init_llama_model
from .rotary_embedding import RotaryValues

class Llama(NamedTuple):
    model: LlamaModel
    lm_head: Any  # Array

def check_llama(params: Llama, *, model_config: ModelConfig) -> None:
    assert isinstance(params.model, LlamaModel)
    assert isinstance(params.lm_head, Array)

    check_llama_model(params.model, model_config=model_config)
    assert params.lm_head.shape == (model_config.d_model, model_config.vocab_size)

def init_llama(*, key: Array, model_config: ModelConfig) -> Llama:
    upper = 1. / math.sqrt(model_config.d_model)
    key0, key1 = rand.split(key)
    model = init_llama_model(key=key0, model_config=model_config)
    lm_head = rand.truncated_normal(key1, -upper, upper, (model_config.d_model, model_config.vocab_size))
    return Llama(model, lm_head)

@partial(jax.jit, static_argnames=('model_config',))
def forward_llama(params: Llama, seq: Array, qk_mask: Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
    outputs, kv_cache = forward_llama_model(params.model, seq, qk_mask, rotary_values=rotary_values, kv_cache=kv_cache, key=key, model_config=model_config)
    logits = outputs @ params.lm_head
    return logits, kv_cache

Functions

def check_llama(params: Llama, *, model_config: ModelConfig) -> None
def check_llama(params: Llama, *, model_config: ModelConfig) -> None:
    assert isinstance(params.model, LlamaModel)
    assert isinstance(params.lm_head, Array)

    check_llama_model(params.model, model_config=model_config)
    assert params.lm_head.shape == (model_config.d_model, model_config.vocab_size)
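A brief usage sketch, assuming a parameter tree params (for example, freshly initialised or deserialised weights) and its matching model_config are already in scope:

check_llama(params, model_config=model_config)  # silent on success, raises AssertionError on any type or shape mismatch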
def init_llama(*, key: jax.Array, model_config: ModelConfig) -> Llama
def init_llama(*, key: Array, model_config: ModelConfig) -> Llama:
    upper = 1. / math.sqrt(model_config.d_model)
    key0, key1 = rand.split(key)
    model = init_llama_model(key=key0, model_config=model_config)
    lm_head = rand.truncated_normal(key1, -upper, upper, (model_config.d_model, model_config.vocab_size))
    return Llama(model, lm_head)
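A sketch of parameter initialisation, assuming a ModelConfig instance named model_config has been constructed elsewhere (only its d_model and vocab_size fields are referenced in this module):

import jax.random as rand

key = rand.PRNGKey(42)
params = init_llama(key=key, model_config=model_config)
check_llama(params, model_config=model_config)  # optional sanity check on the fresh parameters
print(params.lm_head.shape)                     # (d_model, vocab_size)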
def forward_llama(params: Llama, seq: jax.Array, qk_mask: jax.Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None = None, key: jax.Array | None = None, model_config: ModelConfig) -> tuple[jax.Array, KVCache | None]
@partial(jax.jit, static_argnames=('model_config',))
def forward_llama(params: Llama, seq: Array, qk_mask: Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
    outputs, kv_cache = forward_llama_model(params.model, seq, qk_mask, rotary_values=rotary_values, kv_cache=kv_cache, key=key, model_config=model_config)
    logits = outputs @ params.lm_head
    return logits, kv_cache
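A sketch of a single forward (prefill) pass, assuming params and model_config as above and a rotary_values object built with the helpers from .rotary_embedding (not shown in this module). The exact shape the attention layers expect for qk_mask is defined in llama_model, so the causal mask here is illustrative only:

import jax.numpy as jnp

seq = jnp.array([[1, 2, 3, 4]])                                              # (batch, seq_len) token ids
qk_mask = jnp.tril(jnp.ones((seq.shape[1], seq.shape[1]), dtype=jnp.bool_))  # illustrative causal mask
logits, kv_cache = forward_llama(
    params, seq, qk_mask,
    rotary_values=rotary_values,  # precomputed rotary embeddings for these positions
    kv_cache=None,                # no cache on the first step
    key=None,                     # no dropout key at inference time
    model_config=model_config,
)
next_token = jnp.argmax(logits[:, -1, :], axis=-1)                           # greedy pick of the next token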

Classes

class Llama (model: LlamaModel, lm_head: Any)

Llama(model, lm_head)

class Llama(NamedTuple):
    model: LlamaModel
    lm_head: Any  # Array

Ancestors

  • builtins.tuple

Instance variables

var model : LlamaModel

Alias for field number 0

var lm_head : Any

Alias for field number 1
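
Because Llama is a NamedTuple, JAX treats it as a pytree with no extra registration, so the whole parameter tree can be passed through jit (as forward_llama does) or inspected with jax.tree_util. A small sketch, using params from init_llama:

import jax

shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)         # same tree structure, shapes as leaves
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))  # total parameter count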