Module lib.logits_processing

Expand source code
from operator import getitem
from typing import Callable

import jax
from jax import Array
import jax.nn as nn
import jax.numpy as jnp
import jax.random as rand

from ..rand_utils import split_key_nullable

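# TODO: decide whether the processors should validate their inputs. The checks
# below sketch the expected contract (not currently enforced):
# _, seq_len = seq.shape
# assert seq.shape == (batch_size, seq_len)
# assert seq.dtype == jnp.uint16
# assert attn_mask.shape == (batch_size, seq_len)
# assert attn_mask.dtype == jnp.bool_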

def PresencePenaltyProcessor(penalty: float) -> Callable:
    """Build a processor that subtracts a flat ``penalty`` from every token already present.

    A token counts as present in a row if it occurs at least once at a position
    where ``attn_mask`` is true. Each present token is penalised once, no matter
    how often it occurs.

    Args:
        penalty: amount subtracted from the logit of each present token.

    Returns:
        A callable ``inner(logits, *, seq, attn_mask, **kwargs)`` that returns the
        adjusted logits. The last axis of ``logits`` is treated as the vocabulary
        axis. ``seq`` and ``attn_mask`` are mapped row by row; presumably they have
        shape ``(batch_size, seq_len)`` (TODO confirm against callers).
    """
    def inner(logits: Array, *, seq: Array, attn_mask: Array, **kwargs) -> Array:
        *_, vocab_size = logits.shape

        def present_in_row(row: Array, row_mask: Array) -> Array:
            # Accumulate in int32. bincount uses the weights' dtype as its
            # accumulator, and a uint16 accumulator wraps to 0 after 65536
            # occurrences, which would make a very frequent token look absent.
            counts = jnp.bincount(row, weights=row_mask.astype(jnp.int32), length=vocab_size)
            return counts > 0

        exists = jax.vmap(present_in_row)(seq, attn_mask)  # (batch_size, vocab_size)
        return logits - exists * penalty
    return inner

def FrequencyPenaltyProcessor(penalty: float) -> Callable:
    """Build a processor that penalises each token in proportion to how often it occurs.

    For every row, the number of occurrences of each token id at positions where
    ``attn_mask`` is true is counted. ``count * penalty`` is then subtracted from
    that token's logit.

    Args:
        penalty: amount subtracted per occurrence.

    Returns:
        A callable ``inner(logits, *, seq, attn_mask, **kwargs)`` that returns the
        adjusted logits. The last axis of ``logits`` is treated as the vocabulary
        axis. ``seq`` and ``attn_mask`` are mapped row by row; presumably they have
        shape ``(batch_size, seq_len)`` (TODO confirm against callers).
    """
    def inner(logits: Array, *, seq: Array, attn_mask: Array, **kwargs) -> Array:
        *_, vocab_size = logits.shape

        def count_row(row: Array, row_mask: Array) -> Array:
            # int32 weights make bincount accumulate in int32. A uint16
            # accumulator would silently wrap counts modulo 65536.
            return jnp.bincount(row, weights=row_mask.astype(jnp.int32), length=vocab_size)

        counts = jax.vmap(count_row)(seq, attn_mask)  # (batch_size, vocab_size)
        return logits - counts * penalty
    return inner

def TopKSampler(top_k: int) -> Callable:
    """Build a sampler that draws one token per row from the ``top_k`` highest logits.

    Args:
        top_k: number of highest-scoring candidates to sample from. If it is
            larger than the vocabulary, the whole vocabulary is used.

    Returns:
        A callable ``inner(logits, *, key, **kwargs)`` taking ``logits`` of shape
        ``(batch_size, vocab_size)`` and a PRNG key. It returns the sampled token
        ids with shape ``(batch_size,)`` and dtype ``uint16``. It is tagged with
        ``requires_key = True`` so that ``make_logits_processor`` passes it a key.
    """
    def inner(logits: Array, *, key: Array, **kwargs) -> Array:
        batch_size, vocab_size = logits.shape
        # jax.lax.top_k rejects k larger than the axis size, so clamp it here.
        k = min(top_k, vocab_size)
        keys = rand.split(key, num=batch_size)  # one independent key per row

        def sample_row(row_logits: Array, row_key: Array) -> Array:
            values, indices = jax.lax.top_k(row_logits, k=k)
            # NOTE(review): token ids are stored as uint16, which assumes
            # vocab_size <= 65536; larger ids would wrap.
            indices = indices.astype(jnp.uint16)
            selected_index = rand.categorical(row_key, values)
            return indices[selected_index]

        return jax.vmap(sample_row)(logits, keys)
    setattr(inner, 'requires_key', True)
    return inner

def TopPSampler(top_p: float) -> Callable:
    """Build a nucleus (top-p) sampler.

    In each row, tokens are sorted by descending logit. The smallest prefix whose
    cumulative probability reaches ``top_p`` is kept, and one token is sampled
    from that prefix in proportion to its renormalised probability. The most
    likely token is always kept, even when ``top_p <= 0``.

    Args:
        top_p: cumulative-probability cutoff, typically in (0, 1].

    Returns:
        A callable ``inner(logits, *, key, **kwargs)`` taking ``logits`` of shape
        ``(batch_size, vocab_size)`` and returning the sampled token ids with
        shape ``(batch_size,)`` and dtype ``uint16``. It is tagged with
        ``requires_key = True`` so that ``make_logits_processor`` passes it a key.
    """
    def inner(logits: Array, *, key: Array, **kwargs) -> Array:
        batch_size, vocab_size = logits.shape

        # NOTE(review): token ids are uint16, which assumes vocab_size <= 65536.
        indices = jnp.broadcast_to(jnp.arange(vocab_size, dtype=jnp.uint16), (batch_size, vocab_size))
        # Negate the keys so the ascending sort yields descending logits.
        neg_sorted_logits, sorted_indices = jax.lax.sort_key_val(-logits, indices, is_stable=False)
        sorted_logits = -neg_sorted_logits
        sorted_probs = nn.softmax(sorted_logits, axis=-1)
        # Probability mass strictly before each position. A token is kept while the
        # mass before it is still below top_p, so the token that crosses the
        # threshold is included.
        mass_before = jnp.cumsum(sorted_probs, axis=-1) - sorted_probs
        keep = (mass_before < top_p).at[:, 0].set(True)  # always keep the most likely token
        masked_sorted_logits = jnp.where(keep, sorted_logits, -jnp.inf)

        _, subkey = rand.split(key)
        selected_indices = rand.categorical(subkey, masked_sorted_logits)
        selected_token_ids = jax.vmap(getitem)(sorted_indices, selected_indices)  # type: ignore[call-overload]
        return selected_token_ids
    setattr(inner, 'requires_key', True)
    return inner

def make_logits_processor(*callables):
    def inner(logits: Array, seq: Array, attn_mask: Array, key: Array | None) -> Array:
        for f in callables:
            if not getattr(f, 'requires_key', False):
                logits = f(logits, seq=seq, attn_mask=attn_mask)
            else:
                assert key is not None
                key, subkey = rand.split(key)
                logits = f(logits, seq=seq, attn_mask=attn_mask, key=subkey)
        return logits
    return inner

Functions

def PresencePenaltyProcessor(penalty: float) -> Callable
Expand source code
def PresencePenaltyProcessor(penalty: float) -> Callable:
    """Build a processor that subtracts a flat ``penalty`` from every token already present.

    A token counts as present in a row if it occurs at least once at a position
    where ``attn_mask`` is true. Each present token is penalised once, no matter
    how often it occurs.

    Args:
        penalty: amount subtracted from the logit of each present token.

    Returns:
        A callable ``inner(logits, *, seq, attn_mask, **kwargs)`` that returns the
        adjusted logits. The last axis of ``logits`` is treated as the vocabulary
        axis. ``seq`` and ``attn_mask`` are mapped row by row; presumably they have
        shape ``(batch_size, seq_len)`` (TODO confirm against callers).
    """
    def inner(logits: Array, *, seq: Array, attn_mask: Array, **kwargs) -> Array:
        *_, vocab_size = logits.shape

        def present_in_row(row: Array, row_mask: Array) -> Array:
            # Accumulate in int32. bincount uses the weights' dtype as its
            # accumulator, and a uint16 accumulator wraps to 0 after 65536
            # occurrences, which would make a very frequent token look absent.
            counts = jnp.bincount(row, weights=row_mask.astype(jnp.int32), length=vocab_size)
            return counts > 0

        exists = jax.vmap(present_in_row)(seq, attn_mask)  # (batch_size, vocab_size)
        return logits - exists * penalty
    return inner
def FrequencyPenaltyProcessor(penalty: float) -> Callable
Expand source code
def FrequencyPenaltyProcessor(penalty: float) -> Callable:
    """Build a processor that penalises each token in proportion to how often it occurs.

    For every row, the number of occurrences of each token id at positions where
    ``attn_mask`` is true is counted. ``count * penalty`` is then subtracted from
    that token's logit.

    Args:
        penalty: amount subtracted per occurrence.

    Returns:
        A callable ``inner(logits, *, seq, attn_mask, **kwargs)`` that returns the
        adjusted logits. The last axis of ``logits`` is treated as the vocabulary
        axis. ``seq`` and ``attn_mask`` are mapped row by row; presumably they have
        shape ``(batch_size, seq_len)`` (TODO confirm against callers).
    """
    def inner(logits: Array, *, seq: Array, attn_mask: Array, **kwargs) -> Array:
        *_, vocab_size = logits.shape

        def count_row(row: Array, row_mask: Array) -> Array:
            # int32 weights make bincount accumulate in int32. A uint16
            # accumulator would silently wrap counts modulo 65536.
            return jnp.bincount(row, weights=row_mask.astype(jnp.int32), length=vocab_size)

        counts = jax.vmap(count_row)(seq, attn_mask)  # (batch_size, vocab_size)
        return logits - counts * penalty
    return inner
def TopKSampler(top_k: int) -> Callable
Expand source code
def TopKSampler(top_k: int) -> Callable:
    """Build a sampler that draws one token per row from the ``top_k`` highest logits.

    Args:
        top_k: number of highest-scoring candidates to sample from. If it is
            larger than the vocabulary, the whole vocabulary is used.

    Returns:
        A callable ``inner(logits, *, key, **kwargs)`` taking ``logits`` of shape
        ``(batch_size, vocab_size)`` and a PRNG key. It returns the sampled token
        ids with shape ``(batch_size,)`` and dtype ``uint16``. It is tagged with
        ``requires_key = True`` so that ``make_logits_processor`` passes it a key.
    """
    def inner(logits: Array, *, key: Array, **kwargs) -> Array:
        batch_size, vocab_size = logits.shape
        # jax.lax.top_k rejects k larger than the axis size, so clamp it here.
        k = min(top_k, vocab_size)
        keys = rand.split(key, num=batch_size)  # one independent key per row

        def sample_row(row_logits: Array, row_key: Array) -> Array:
            values, indices = jax.lax.top_k(row_logits, k=k)
            # NOTE(review): token ids are stored as uint16, which assumes
            # vocab_size <= 65536; larger ids would wrap.
            indices = indices.astype(jnp.uint16)
            selected_index = rand.categorical(row_key, values)
            return indices[selected_index]

        return jax.vmap(sample_row)(logits, keys)
    setattr(inner, 'requires_key', True)
    return inner
def TopPSampler(top_p: float) -> Callable
Expand source code
def TopPSampler(top_p: float) -> Callable:
    """Build a nucleus (top-p) sampler.

    In each row, tokens are sorted by descending logit. The smallest prefix whose
    cumulative probability reaches ``top_p`` is kept, and one token is sampled
    from that prefix in proportion to its renormalised probability. The most
    likely token is always kept, even when ``top_p <= 0``.

    Args:
        top_p: cumulative-probability cutoff, typically in (0, 1].

    Returns:
        A callable ``inner(logits, *, key, **kwargs)`` taking ``logits`` of shape
        ``(batch_size, vocab_size)`` and returning the sampled token ids with
        shape ``(batch_size,)`` and dtype ``uint16``. It is tagged with
        ``requires_key = True`` so that ``make_logits_processor`` passes it a key.
    """
    def inner(logits: Array, *, key: Array, **kwargs) -> Array:
        batch_size, vocab_size = logits.shape

        # NOTE(review): token ids are uint16, which assumes vocab_size <= 65536.
        indices = jnp.broadcast_to(jnp.arange(vocab_size, dtype=jnp.uint16), (batch_size, vocab_size))
        # Negate the keys so the ascending sort yields descending logits.
        neg_sorted_logits, sorted_indices = jax.lax.sort_key_val(-logits, indices, is_stable=False)
        sorted_logits = -neg_sorted_logits
        sorted_probs = nn.softmax(sorted_logits, axis=-1)
        # Probability mass strictly before each position. A token is kept while the
        # mass before it is still below top_p, so the token that crosses the
        # threshold is included.
        mass_before = jnp.cumsum(sorted_probs, axis=-1) - sorted_probs
        keep = (mass_before < top_p).at[:, 0].set(True)  # always keep the most likely token
        masked_sorted_logits = jnp.where(keep, sorted_logits, -jnp.inf)

        _, subkey = rand.split(key)
        selected_indices = rand.categorical(subkey, masked_sorted_logits)
        selected_token_ids = jax.vmap(getitem)(sorted_indices, selected_indices)  # type: ignore[call-overload]
        return selected_token_ids
    setattr(inner, 'requires_key', True)
    return inner
def make_logits_processor(*callables)
Expand source code
def make_logits_processor(*callables):
    def inner(logits: Array, seq: Array, attn_mask: Array, key: Array | None) -> Array:
        for f in callables:
            if not getattr(f, 'requires_key', False):
                logits = f(logits, seq=seq, attn_mask=attn_mask)
            else:
                assert key is not None
                key, subkey = rand.split(key)
                logits = f(logits, seq=seq, attn_mask=attn_mask, key=subkey)
        return logits
    return inner