Module lib.llama.kv_cache

Expand source code
from typing import Any, NamedTuple

import jax.numpy as jnp

class KVCache(NamedTuple):
    """Pair of key-cache and value-cache arrays bundled as a named tuple.

    Because this is a ``NamedTuple``, an instance can be unpacked
    positionally as ``(k_cache, v_cache)`` or accessed by field name.
    """
    k_cache: Any  # Array: the cached keys (annotated as Any)
    v_cache: Any  # Array: the cached values (annotated as Any)

def shift_left_kv_cache(kv_cache: KVCache) -> KVCache:
    """Return a new ``KVCache`` with both arrays rolled by -1 along axis -2.

    Every entry moves one index lower along the second-to-last axis, and
    the entry at index 0 wraps around to the last index (``jnp.roll``
    semantics). The same roll is applied to the key and value arrays.

    Args:
        kv_cache: The ``(k_cache, v_cache)`` pair to shift.

    Returns:
        A ``KVCache`` holding the rolled key and value arrays.
    """
    keys, values = kv_cache
    # Axis -2 is labelled "dimension L" by the original author
    # (presumably the sequence-length axis; confirm against callers).
    return KVCache(
        jnp.roll(keys, shift=-1, axis=-2),
        jnp.roll(values, shift=-1, axis=-2),
    )

Functions

def shift_left_kv_cache(kv_cache: KVCache) ‑> KVCache
Expand source code
def shift_left_kv_cache(kv_cache: KVCache) -> KVCache:
    """Return a new ``KVCache`` with both arrays rolled by -1 along axis -2.

    Every entry moves one index lower along the second-to-last axis, and
    the entry at index 0 wraps around to the last index (``jnp.roll``
    semantics). The same roll is applied to the key and value arrays.

    Args:
        kv_cache: The ``(k_cache, v_cache)`` pair to shift.

    Returns:
        A ``KVCache`` holding the rolled key and value arrays.
    """
    keys, values = kv_cache
    # Axis -2 is labelled "dimension L" by the original author
    # (presumably the sequence-length axis; confirm against callers).
    return KVCache(
        jnp.roll(keys, shift=-1, axis=-2),
        jnp.roll(values, shift=-1, axis=-2),
    )

Classes

class KVCache (k_cache: Any, v_cache: Any)

KVCache(k_cache, v_cache)

Expand source code
class KVCache(NamedTuple):
    """Pair of key-cache and value-cache arrays bundled as a named tuple.

    Because this is a ``NamedTuple``, an instance can be unpacked
    positionally as ``(k_cache, v_cache)`` or accessed by field name.
    """
    k_cache: Any  # Array: the cached keys (annotated as Any)
    v_cache: Any  # Array: the cached values (annotated as Any)

Ancestors

  • builtins.tuple

Instance variables

var k_cache : Any

Alias for field number 0

var v_cache : Any

Alias for field number 1