Module lib.llama.kv_cache
Expand source code
from typing import Any, NamedTuple
import jax.numpy as jnp
class KVCache(NamedTuple):
    """Pair of key and value caches, bundled as a named tuple.

    Fields are positional: ``k_cache`` is field 0 and ``v_cache`` is field 1,
    so an instance unpacks as ``k_cache, v_cache = kv_cache``.

    Both fields are annotated ``Any``; the inline comments mark them as
    arrays (presumably JAX arrays -- confirm against callers).
    """

    k_cache: Any  # Array
    v_cache: Any  # Array
def shift_left_kv_cache(kv_cache: KVCache) -> KVCache:
    """Rotate both caches one position to the left along axis -2.

    Each array is passed through ``jnp.roll(..., -1, axis=-2)``: every entry
    moves one index lower on the second-to-last axis, and the entry that was
    at index 0 wraps around to the last index. The roll is circular, so
    nothing is dropped or zeroed.

    Args:
        kv_cache: The ``(k_cache, v_cache)`` pair to shift. Each array needs
            at least two dimensions for ``axis=-2`` to be valid.

    Returns:
        A new ``KVCache`` holding the rolled key and value arrays.
    """
    k_cache, v_cache = kv_cache
    # NOTE(review): "L" is presumably the sequence-length axis, and the
    # wrapped-around last slot is presumably overwritten by the caller --
    # confirm against the call site.
    k_cache = jnp.roll(k_cache, -1, axis=-2)  # -2: dimension L
    v_cache = jnp.roll(v_cache, -1, axis=-2)  # -2: dimension L
    return KVCache(k_cache, v_cache)
Functions
def shift_left_kv_cache(kv_cache: KVCache) -> KVCache
-
Expand source code
def shift_left_kv_cache(kv_cache: KVCache) -> KVCache:
    """Rotate both caches one position to the left along axis -2.

    Each array is passed through ``jnp.roll(..., -1, axis=-2)``: every entry
    moves one index lower on the second-to-last axis, and the entry that was
    at index 0 wraps around to the last index. The roll is circular, so
    nothing is dropped or zeroed.

    Args:
        kv_cache: The ``(k_cache, v_cache)`` pair to shift. Each array needs
            at least two dimensions for ``axis=-2`` to be valid.

    Returns:
        A new ``KVCache`` holding the rolled key and value arrays.
    """
    k_cache, v_cache = kv_cache
    k_cache = jnp.roll(k_cache, -1, axis=-2)  # -2: dimension L
    v_cache = jnp.roll(v_cache, -1, axis=-2)  # -2: dimension L
    return KVCache(k_cache, v_cache)
Classes
class KVCache (k_cache: Any, v_cache: Any)
-
KVCache(k_cache, v_cache)
Expand source code
class KVCache(NamedTuple):
    """Pair of key and value caches, bundled as a named tuple.

    Fields are positional: ``k_cache`` is field 0 and ``v_cache`` is field 1,
    so an instance unpacks as ``k_cache, v_cache = kv_cache``.

    Both fields are annotated ``Any``; the inline comments mark them as
    arrays (presumably JAX arrays -- confirm against callers).
    """

    k_cache: Any  # Array
    v_cache: Any  # Array
Ancestors
- builtins.tuple
Instance variables
var k_cache : Any
-
Alias for field number 0
var v_cache : Any
-
Alias for field number 1