Module lib.rand_utils

Expand source code
from itertools import repeat

from jax import Array
import jax.random as rand

def split_key_nullable(key: Array | None, num: int=2):
    if key is None:
        return tuple(repeat(None, num))
    return rand.split(key, num)

Functions

def split_key_nullable(key: jax.Array | None, num: int = 2)
Expand source code
def split_key_nullable(key: Array | None, num: int=2):
    if key is None:
        return tuple(repeat(None, num))
    return rand.split(key, num)