Module lib.tree_utils
Expand source code
import jax
import jax.numpy as jnp
# https://docs.liesel-project.org/en/v0.1.4/_modules/liesel/goose/pytree.html#stack_leaves
def stack_leaves(pytrees, axis: int=0):
'''
Stack the leaves of one or more PyTrees along a new axis.
Args:
pytrees: One or more PyTrees.
axis (int, optional): The axis along which the arrays will be stacked. Default is 0.
Returns:
The PyTree with its leaves stacked along the new axis.
'''
return jax.tree_map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)
# https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75?permalink_comment_id=4634557#gistcomment-4634557
def unstack_leaves(pytrees):
'''
Unstack the leaves of a PyTree.
Args:
pytrees: A PyTree.
Returns:
A list of PyTrees, where each PyTree has the same structure as the input PyTree, but each leaf contains only one part of the original leaf.
'''
leaves, treedef = jax.tree_util.tree_flatten(pytrees)
return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]
Functions
def stack_leaves(pytrees, axis: int = 0)
-
Stack the leaves of one or more PyTrees along a new axis.
Args
pytrees
- A sequence of one or more PyTrees that share the same structure.
axis : int, optional
- The axis along which the arrays will be stacked. Default is 0.
Returns
The PyTree with its leaves stacked along the new axis.
Expand source code
def stack_leaves(pytrees, axis: int=0): ''' Stack the leaves of one or more PyTrees along a new axis. Args: pytrees: One or more PyTrees. axis (int, optional): The axis along which the arrays will be stacked. Default is 0. Returns: The PyTree with its leaves stacked along the new axis. ''' return jax.tree_map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)
def unstack_leaves(pytrees)
-
Unstack the leaves of a PyTree.
Args
pytrees
- A PyTree.
Returns
A list of PyTrees, where each PyTree has the same structure as the input PyTree, but each leaf holds a single slice, taken along the leading axis, of the corresponding original leaf.
Expand source code
def unstack_leaves(pytrees): ''' Unstack the leaves of a PyTree. Args: pytrees: A PyTree. Returns: A list of PyTrees, where each PyTree has the same structure as the input PyTree, but each leaf contains only one part of the original leaf. ''' leaves, treedef = jax.tree_util.tree_flatten(pytrees) return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]