Module lib.rand_utils
Expand source code
from itertools import repeat
from jax import Array
import jax.random as rand
def split_key_nullable(key: Array | None, num: int=2):
if key is None:
return tuple(repeat(None, num))
return rand.split(key, num)
Functions
def split_key_nullable(key: jax.Array | None, num: int = 2)
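Split a JAX PRNG `key` into `num` subkeys with `jax.random.split`; if `key` is `None`, return a tuple of `num` `None` values instead.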
Expand source code
def split_key_nullable(key: Array | None, num: int=2): if key is None: return tuple(repeat(None, num)) return rand.split(key, num)