Module lib.llama.attention
from functools import partial
import math
from typing import Any, NamedTuple
import einops as op
import jax
from jax import Array
import jax.nn as nn
import jax.numpy as jnp
import jax.random as rand
from .ModelConfig import ModelConfig
from .kv_cache import KVCache
from .rotary_embedding import RotaryValues, forward_rotary_embedding
class Attention(NamedTuple):
    q_proj: Any  # Array
    k_proj: Any  # Array
    v_proj: Any  # Array
    out_proj: Any  # Array

def check_attention(params: Attention, *, model_config: ModelConfig) -> None:
    assert isinstance(params.q_proj, Array)
    assert isinstance(params.k_proj, Array)
    assert isinstance(params.v_proj, Array)
    assert isinstance(params.out_proj, Array)
    assert params.q_proj.shape == (model_config.d_model, model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_k)
    assert params.k_proj.shape == (model_config.d_model, model_config.n_heads_kv, model_config.d_k)
    assert params.v_proj.shape == (model_config.d_model, model_config.n_heads_kv, model_config.d_v)
    assert params.out_proj.shape == (model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_v, model_config.d_model)

def init_attention(*, key: Array, model_config: ModelConfig) -> Attention:
    upper = 1. / math.sqrt(model_config.d_model)
    key0, key1, key2, key3 = rand.split(key, num=4)
    q_proj = rand.truncated_normal(key0, -upper, upper, (model_config.d_model, model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_k))
    k_proj = rand.truncated_normal(key1, -upper, upper, (model_config.d_model, model_config.n_heads_kv, model_config.d_k))
    v_proj = rand.truncated_normal(key2, -upper, upper, (model_config.d_model, model_config.n_heads_kv, model_config.d_v))
    out_proj = rand.truncated_normal(key3, -upper, upper, (model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_v, model_config.d_model))
    return Attention(q_proj, k_proj, v_proj, out_proj)

@partial(jax.jit, static_argnames=('model_config',))
def forward_attention(params: Attention, src_seq: Array, dst_seq: Array, qk_mask: Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
    # Axis letters: B batch, S query positions, D key/value positions, M d_model,
    # R n_rep_kv (query heads per KV head), H n_heads_kv, K d_k, V d_v.
    q = op.einsum(src_seq, params.q_proj, 'B S M, M R H K -> B R H S K')
    k = op.einsum(dst_seq, params.k_proj, 'B D M, M H K -> B H D K')
    v = op.einsum(dst_seq, params.v_proj, 'B D M, M H V -> B H D V')
    q = forward_rotary_embedding(q, rotary_values=rotary_values)
    k = forward_rotary_embedding(k, rotary_values=rotary_values)
    if kv_cache is not None:
        # Single-token decoding step: write the new key/value into the last cache slot.
        assert src_seq.shape[1] == 1
        assert dst_seq.shape[1] == 1
        k_cache, v_cache = kv_cache
        k = k_cache.at[:, :, -1:].set(k)
        v = v_cache.at[:, :, -1:].set(v)
    qk = op.einsum(q, k, 'B R H S K, B H D K -> B R H S D')
    qk /= math.sqrt(model_config.d_k)
    qk = jnp.where(qk_mask, qk, -jnp.inf)
    qk = nn.softmax(qk)  # TODO: use `where`
    # qk = nn.softmax(qk, where=qk_mask, initial=0.)
    qk = jnp.where(qk_mask, qk, 0)  # zero out fully masked rows, which the softmax above turns into NaN
    qkv = op.einsum(qk, v, 'B R H S D, B H D V -> B R H S V')
    out = op.einsum(qkv, params.out_proj, 'B R H S V, R H V M -> B S M')
    kv_cache = None if not model_config.return_kv_cache else KVCache(k, v)
    return out, kv_cache
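To make the einsum shape algebra in forward_attention concrete, here is a small self-contained sketch with made-up dimensions (B=1, S=D=5, M=32, R=2, H=4, K=V=8). It reproduces only the projection and attention contractions; rotary embeddings, masking and the KV cache are omitted, so treat it as an illustration of the grouped-query layout rather than the module's actual forward pass.

import math
import einops as op
import jax.nn as nn
import jax.random as rand

# Made-up dimensions for illustration only.
B, S, D, M, R, H, K, V = 1, 5, 5, 32, 2, 4, 8, 8

key0, key1, key2, key3, key4 = rand.split(rand.PRNGKey(0), num=5)
src_seq = rand.normal(key0, (B, S, M))   # query-side activations
dst_seq = src_seq                        # self-attention: keys and values come from the same sequence
q_proj = rand.normal(key1, (M, R, H, K))
k_proj = rand.normal(key2, (M, H, K))
v_proj = rand.normal(key3, (M, H, V))
out_proj = rand.normal(key4, (R, H, V, M))

q = op.einsum(src_seq, q_proj, 'B S M, M R H K -> B R H S K')   # (1, 2, 4, 5, 8)
k = op.einsum(dst_seq, k_proj, 'B D M, M H K -> B H D K')       # (1, 4, 5, 8)
v = op.einsum(dst_seq, v_proj, 'B D M, M H V -> B H D V')       # (1, 4, 5, 8)

# Each of the H KV heads is shared by R query heads (grouped-query attention).
scores = op.einsum(q, k, 'B R H S K, B H D K -> B R H S D') / math.sqrt(K)
weights = nn.softmax(scores, axis=-1)    # normalise over key positions D
out = op.einsum(weights, v, 'B R H S D, B H D V -> B R H S V')
out = op.einsum(out, out_proj, 'B R H S V, R H V M -> B S M')
print(out.shape)                         # (1, 5, 32)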
Functions
def check_attention(params: Attention, *, model_config: ModelConfig) -> None
-
Asserts that q_proj, k_proj, v_proj and out_proj are jax.Array instances with the shapes required by model_config.
def init_attention(*, key: jax.Array, model_config: ModelConfig) -> Attention
-
Initialises all four projection matrices by sampling from a truncated normal distribution bounded by ±1/sqrt(d_model).
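A minimal standalone sketch of that initialisation scheme for a single projection, with assumed toy dimensions in place of a real ModelConfig:

import math
import jax.random as rand

d_model, n_heads_kv, d_k = 32, 4, 8   # assumed toy dimensions
upper = 1. / math.sqrt(d_model)

key = rand.PRNGKey(0)
k_proj = rand.truncated_normal(key, -upper, upper, (d_model, n_heads_kv, d_k))
assert k_proj.shape == (d_model, n_heads_kv, d_k)
assert float(k_proj.min()) >= -upper and float(k_proj.max()) <= upper  # samples stay inside the bounds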
def forward_attention(params: Attention, src_seq: jax.Array, dst_seq: jax.Array, qk_mask: jax.Array, *, rotary_values: RotaryValues, kv_cache: KVCache | None = None, model_config: ModelConfig) -> tuple[jax.Array, KVCache | None]
-
Computes grouped-query attention: projects src_seq to queries and dst_seq to keys and values, applies rotary embeddings, masks and normalises the scaled attention scores, and projects the result back to d_model. When kv_cache is given, the call is treated as a single-token decoding step and the new key and value are written into the last cache position. Returns the output sequence together with the updated KVCache if model_config.return_kv_cache is set, otherwise None.
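A detail worth noting in forward_attention: masked positions are set to -inf before the softmax, and a second jnp.where afterwards zeroes out rows in which every position is masked, since a softmax over a row of all -inf evaluates to NaN. A small sketch of that behaviour with illustrative values:

import jax.nn as nn
import jax.numpy as jnp

mask = jnp.array([False, False, False])        # an attention row with every position masked
row = jnp.where(mask, jnp.array([1., 2., 3.]), -jnp.inf)

print(nn.softmax(row))                         # [nan nan nan]
print(jnp.where(mask, nn.softmax(row), 0))     # [0. 0. 0.], which is what the extra where() restores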
Classes
class Attention (q_proj: Any, k_proj: Any, v_proj: Any, out_proj: Any)
-
Attention(q_proj, k_proj, v_proj, out_proj)
Ancestors
- builtins.tuple
Instance variables
var q_proj : Any
-
Alias for field number 0
var k_proj : Any
-
Alias for field number 1
var v_proj : Any
-
Alias for field number 2
var out_proj : Any
-
Alias for field number 3