Module lib.llama.decoder_block
Expand source code
from functools import partial
import math
from typing import Any, NamedTuple
import jax
from jax import Array
import jax.random as rand
from ..rand_utils import split_key_nullable
from .ModelConfig import ModelConfig
from .attention import Attention, check_attention, forward_attention, init_attention
from .dropout import forward_dropout
from .kv_cache import KVCache
from .rms_norm import check_rms_norm, forward_rms_norm, init_rms_norm
from .rotary_embedding import RotaryValues
class DecoderBlock(NamedTuple):
    """Parameters of a single decoder block.

    The array-valued fields are annotated ``Any``, but
    ``check_decoder_block`` asserts that each of them is a ``jax.Array``.
    The shapes quoted below are the ones ``check_decoder_block`` enforces.
    """

    # Parameters of the RMS norm applied to the block input, ahead of attention.
    input_norm: Any  # Array

    # Parameters of the attention sub-layer.
    attention: Attention

    # Parameters of the RMS norm applied ahead of the feed-forward sub-layer.
    post_attn_norm: Any  # Array

    # Feed-forward gate projection, shape (d_model, d_ff).
    gate_proj: Any  # Array

    # Feed-forward up projection, shape (d_model, d_ff).
    up_proj: Any  # Array

    # Feed-forward down projection, shape (d_ff, d_model).
    down_proj: Any  # Array
def check_decoder_block(params: DecoderBlock, *, model_config: ModelConfig) -> None:
    """Assert that ``params`` is a well-formed decoder block for ``model_config``.

    The checks run in three stages: the type of every field, then the checks
    delegated to ``check_rms_norm`` and ``check_attention``, and finally the
    shapes of the three feed-forward projection matrices.

    Args:
        params: Decoder block parameters to validate.
        model_config: Configuration supplying ``d_model`` and ``d_ff``; it is
            also forwarded to the delegated checks.

    Raises:
        AssertionError: If a field has the wrong type or a projection has the
            wrong shape. The delegated checks may raise as well.
    """
    # Stage 1: field types, checked one at a time in declaration order.
    field_types = (
        ('input_norm', Array),
        ('attention', Attention),
        ('post_attn_norm', Array),
        ('gate_proj', Array),
        ('up_proj', Array),
        ('down_proj', Array),
    )
    for field_name, expected_type in field_types:
        assert isinstance(getattr(params, field_name), expected_type)

    # Stage 2: component-specific validation lives with each component.
    check_rms_norm(params.input_norm, model_config=model_config)
    check_attention(params.attention, model_config=model_config)
    check_rms_norm(params.post_attn_norm, model_config=model_config)

    # Stage 3: gate/up map d_model -> d_ff, down maps d_ff -> d_model.
    d_model, d_ff = model_config.d_model, model_config.d_ff
    projection_shapes = (
        ('gate_proj', (d_model, d_ff)),
        ('up_proj', (d_model, d_ff)),
        ('down_proj', (d_ff, d_model)),
    )
    for field_name, expected_shape in projection_shapes:
        assert getattr(params, field_name).shape == expected_shape
def init_decoder_block(*, key: Array, model_config: ModelConfig) -> DecoderBlock:
    """Create freshly initialised parameters for one decoder block.

    ``key`` is split into four sub-keys: one for the attention parameters and
    one for each of the three feed-forward projections. The two RMS norms are
    initialised by ``init_rms_norm``, which takes no key.

    Args:
        key: JAX PRNG key consumed by this block.
        model_config: Configuration supplying ``d_model`` and ``d_ff``; it is
            also forwarded to the component initialisers.

    Returns:
        A ``DecoderBlock`` holding the new parameters.
    """
    d_model, d_ff = model_config.d_model, model_config.d_ff
    bound = 1. / math.sqrt(d_model)
    attn_key, gate_key, up_key, down_key = rand.split(key, num=4)

    def sample_projection(subkey: Array, shape: tuple[int, int]) -> Array:
        # Draws from jax.random.truncated_normal with the truncation interval
        # [-bound, bound], where bound = 1 / sqrt(d_model).
        return rand.truncated_normal(subkey, -bound, bound, shape)

    return DecoderBlock(
        input_norm=init_rms_norm(model_config=model_config),
        attention=init_attention(key=attn_key, model_config=model_config),
        post_attn_norm=init_rms_norm(model_config=model_config),
        gate_proj=sample_projection(gate_key, (d_model, d_ff)),
        up_proj=sample_projection(up_key, (d_model, d_ff)),
        down_proj=sample_projection(down_key, (d_ff, d_model)),
    )
@partial(jax.jit, static_argnames=('model_config',))
def forward_decoder_block(
    params: DecoderBlock,
    seq: Array,
    qk_mask: Array,
    *,
    rotary_values: RotaryValues,
    kv_cache: KVCache | None=None,
    key: Array | None=None,
    model_config: ModelConfig,
) -> tuple[Array, KVCache | None]:
    """Run one decoder block: attention, then a SiLU-gated feed-forward layer.

    Each sub-layer normalises its input with an RMS norm first and adds its
    dropout-regularised output back onto the un-normalised input (a residual
    connection). The function is compiled with ``jax.jit``, with
    ``model_config`` treated as a static argument.

    Args:
        params: Parameters of this decoder block.
        seq: Input activations; the result has the same role for the next block.
        qk_mask: Mask forwarded unchanged to ``forward_attention``.
        rotary_values: Rotary-embedding values forwarded to ``forward_attention``.
        kv_cache: Optional cache forwarded to ``forward_attention``.
        key: Optional PRNG key, split three ways for the three dropout calls.
            NOTE(review): presumably ``split_key_nullable`` yields ``None``
            sub-keys when ``key`` is ``None`` - confirm in ``rand_utils``.
        model_config: Model configuration passed to every sub-component.

    Returns:
        A pair of the block output and the ``kv_cache`` value returned by
        ``forward_attention``.
    """
    attn_drop_key, ff_drop_key, out_drop_key = split_key_nullable(key, num=3)

    # --- Attention sub-layer -------------------------------------------------
    residual = seq
    normed = forward_rms_norm(params.input_norm, seq, model_config=model_config)
    # The normalised sequence is passed for both sequence arguments.
    attn_out, kv_cache = forward_attention(
        params.attention,
        normed,
        normed,
        qk_mask,
        rotary_values=rotary_values,
        kv_cache=kv_cache,
        model_config=model_config,
    )
    attn_out = forward_dropout(attn_out, key=attn_drop_key, model_config=model_config)
    hidden = attn_out + residual

    # --- Feed-forward sub-layer ----------------------------------------------
    residual = hidden
    normed = forward_rms_norm(params.post_attn_norm, hidden, model_config=model_config)
    gate = jax.nn.silu(normed @ params.gate_proj)
    up = normed @ params.up_proj
    # Dropout is applied to the gated activations and again after projecting
    # back down, each with its own sub-key.
    gated = forward_dropout(gate * up, key=ff_drop_key, model_config=model_config)
    ff_out = gated @ params.down_proj
    ff_out = forward_dropout(ff_out, key=out_drop_key, model_config=model_config)

    return ff_out + residual, kv_cache
Functions
def check_decoder_block(params: DecoderBlock, *, model_config: ModelConfig) ‑> None
-
Expand source code
def check_decoder_block(params: DecoderBlock, *, model_config: ModelConfig) -> None:
    """Assert that ``params`` is a well-formed decoder block for ``model_config``.

    The checks run in three stages: the type of every field, then the checks
    delegated to ``check_rms_norm`` and ``check_attention``, and finally the
    shapes of the three feed-forward projection matrices.

    Args:
        params: Decoder block parameters to validate.
        model_config: Configuration supplying ``d_model`` and ``d_ff``; it is
            also forwarded to the delegated checks.

    Raises:
        AssertionError: If a field has the wrong type or a projection has the
            wrong shape. The delegated checks may raise as well.
    """
    # Stage 1: field types, checked one at a time in declaration order.
    field_types = (
        ('input_norm', Array),
        ('attention', Attention),
        ('post_attn_norm', Array),
        ('gate_proj', Array),
        ('up_proj', Array),
        ('down_proj', Array),
    )
    for field_name, expected_type in field_types:
        assert isinstance(getattr(params, field_name), expected_type)

    # Stage 2: component-specific validation lives with each component.
    check_rms_norm(params.input_norm, model_config=model_config)
    check_attention(params.attention, model_config=model_config)
    check_rms_norm(params.post_attn_norm, model_config=model_config)

    # Stage 3: gate/up map d_model -> d_ff, down maps d_ff -> d_model.
    d_model, d_ff = model_config.d_model, model_config.d_ff
    projection_shapes = (
        ('gate_proj', (d_model, d_ff)),
        ('up_proj', (d_model, d_ff)),
        ('down_proj', (d_ff, d_model)),
    )
    for field_name, expected_shape in projection_shapes:
        assert getattr(params, field_name).shape == expected_shape
def init_decoder_block(*, key: jax.Array, model_config: ModelConfig) ‑> DecoderBlock
-
Expand source code
def init_decoder_block(*, key: Array, model_config: ModelConfig) -> DecoderBlock:
    """Create freshly initialised parameters for one decoder block.

    ``key`` is split into four sub-keys: one for the attention parameters and
    one for each of the three feed-forward projections. The two RMS norms are
    initialised by ``init_rms_norm``, which takes no key.

    Args:
        key: JAX PRNG key consumed by this block.
        model_config: Configuration supplying ``d_model`` and ``d_ff``; it is
            also forwarded to the component initialisers.

    Returns:
        A ``DecoderBlock`` holding the new parameters.
    """
    d_model, d_ff = model_config.d_model, model_config.d_ff
    bound = 1. / math.sqrt(d_model)
    attn_key, gate_key, up_key, down_key = rand.split(key, num=4)

    def sample_projection(subkey: Array, shape: tuple[int, int]) -> Array:
        # Draws from jax.random.truncated_normal with the truncation interval
        # [-bound, bound], where bound = 1 / sqrt(d_model).
        return rand.truncated_normal(subkey, -bound, bound, shape)

    return DecoderBlock(
        input_norm=init_rms_norm(model_config=model_config),
        attention=init_attention(key=attn_key, model_config=model_config),
        post_attn_norm=init_rms_norm(model_config=model_config),
        gate_proj=sample_projection(gate_key, (d_model, d_ff)),
        up_proj=sample_projection(up_key, (d_model, d_ff)),
        down_proj=sample_projection(down_key, (d_ff, d_model)),
    )
def forward_decoder_block(params: DecoderBlock, seq: jax.Array, qk_mask: jax.Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None = None, key: jax.Array | None = None, model_config: ModelConfig) ‑> tuple[jax.Array, KVCache | None]
-
Expand source code
@partial(jax.jit, static_argnames=('model_config',))
def forward_decoder_block(
    params: DecoderBlock,
    seq: Array,
    qk_mask: Array,
    *,
    rotary_values: RotaryValues,
    kv_cache: KVCache | None=None,
    key: Array | None=None,
    model_config: ModelConfig,
) -> tuple[Array, KVCache | None]:
    """Run one decoder block: attention, then a SiLU-gated feed-forward layer.

    Each sub-layer normalises its input with an RMS norm first and adds its
    dropout-regularised output back onto the un-normalised input (a residual
    connection). The function is compiled with ``jax.jit``, with
    ``model_config`` treated as a static argument.

    Args:
        params: Parameters of this decoder block.
        seq: Input activations; the result has the same role for the next block.
        qk_mask: Mask forwarded unchanged to ``forward_attention``.
        rotary_values: Rotary-embedding values forwarded to ``forward_attention``.
        kv_cache: Optional cache forwarded to ``forward_attention``.
        key: Optional PRNG key, split three ways for the three dropout calls.
            NOTE(review): presumably ``split_key_nullable`` yields ``None``
            sub-keys when ``key`` is ``None`` - confirm in ``rand_utils``.
        model_config: Model configuration passed to every sub-component.

    Returns:
        A pair of the block output and the ``kv_cache`` value returned by
        ``forward_attention``.
    """
    attn_drop_key, ff_drop_key, out_drop_key = split_key_nullable(key, num=3)

    # --- Attention sub-layer -------------------------------------------------
    residual = seq
    normed = forward_rms_norm(params.input_norm, seq, model_config=model_config)
    # The normalised sequence is passed for both sequence arguments.
    attn_out, kv_cache = forward_attention(
        params.attention,
        normed,
        normed,
        qk_mask,
        rotary_values=rotary_values,
        kv_cache=kv_cache,
        model_config=model_config,
    )
    attn_out = forward_dropout(attn_out, key=attn_drop_key, model_config=model_config)
    hidden = attn_out + residual

    # --- Feed-forward sub-layer ----------------------------------------------
    residual = hidden
    normed = forward_rms_norm(params.post_attn_norm, hidden, model_config=model_config)
    gate = jax.nn.silu(normed @ params.gate_proj)
    up = normed @ params.up_proj
    # Dropout is applied to the gated activations and again after projecting
    # back down, each with its own sub-key.
    gated = forward_dropout(gate * up, key=ff_drop_key, model_config=model_config)
    ff_out = gated @ params.down_proj
    ff_out = forward_dropout(ff_out, key=out_drop_key, model_config=model_config)

    return ff_out + residual, kv_cache
Classes
class DecoderBlock (input_norm: Any, attention: Attention, post_attn_norm: Any, gate_proj: Any, up_proj: Any, down_proj: Any)
-
DecoderBlock(input_norm, attention, post_attn_norm, gate_proj, up_proj, down_proj)
Expand source code
class DecoderBlock(NamedTuple):
    """Parameters of a single decoder block.

    The array-valued fields are annotated ``Any``, but
    ``check_decoder_block`` asserts that each of them is a ``jax.Array``.
    The shapes quoted below are the ones ``check_decoder_block`` enforces.
    """

    # Parameters of the RMS norm applied to the block input, ahead of attention.
    input_norm: Any  # Array

    # Parameters of the attention sub-layer.
    attention: Attention

    # Parameters of the RMS norm applied ahead of the feed-forward sub-layer.
    post_attn_norm: Any  # Array

    # Feed-forward gate projection, shape (d_model, d_ff).
    gate_proj: Any  # Array

    # Feed-forward up projection, shape (d_model, d_ff).
    up_proj: Any  # Array

    # Feed-forward down projection, shape (d_ff, d_model).
    down_proj: Any  # Array
Ancestors
- builtins.tuple
Instance variables
var input_norm : Any
-
Alias for field number 0
var attention : Attention
-
Alias for field number 1
var post_attn_norm : Any
-
Alias for field number 2
var gate_proj : Any
-
Alias for field number 3
var up_proj : Any
-
Alias for field number 4
var down_proj : Any
-
Alias for field number 5