Module lib.multihost_utils.shard_array

Expand source code
from types import EllipsisType

import jax
from jax import Array
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np

def shard_array(arr: Array, axis: int | EllipsisType) -> Array:
    shape = arr.shape
    devices: np.ndarray = np.array(jax.devices())

    if axis is ...:
        mesh = Mesh(devices, ('a',))
        sharding = NamedSharding(mesh, P(None))
    else:
        sharding_tuple_ = [1] * len(shape)
        sharding_tuple_[axis] = -1
        sharding_tuple = tuple(sharding_tuple_)
        name_tuple = tuple('abcdefghijklmnopqrstuvwxyz'[:len(shape)])

        mesh = Mesh(devices.reshape(sharding_tuple), name_tuple)
        sharding = NamedSharding(mesh, P(*name_tuple))

    xs = [jax.device_put(arr[i], device) for device, i in sharding.addressable_devices_indices_map(shape).items()]
    return jax.make_array_from_single_device_arrays(shape, sharding, xs)

Functions

def shard_array(arr: jax.Array, axis: int | ellipsis) ‑> jax.Array
Expand source code
def shard_array(arr: Array, axis: int | EllipsisType) -> Array:
    shape = arr.shape
    devices: np.ndarray = np.array(jax.devices())

    if axis is ...:
        mesh = Mesh(devices, ('a',))
        sharding = NamedSharding(mesh, P(None))
    else:
        sharding_tuple_ = [1] * len(shape)
        sharding_tuple_[axis] = -1
        sharding_tuple = tuple(sharding_tuple_)
        name_tuple = tuple('abcdefghijklmnopqrstuvwxyz'[:len(shape)])

        mesh = Mesh(devices.reshape(sharding_tuple), name_tuple)
        sharding = NamedSharding(mesh, P(*name_tuple))

    xs = [jax.device_put(arr[i], device) for device, i in sharding.addressable_devices_indices_map(shape).items()]
    return jax.make_array_from_single_device_arrays(shape, sharding, xs)