Module lib.array_utils
Expand source code
import torch
from jax import Array
import jax.numpy as jnp
import numpy as np
def jax2np(x: Array) -> np.ndarray:
    """Return the contents of a JAX array as a NumPy ndarray.

    The conversion goes through ``np.asarray``, so no copy is forced. The
    result may be a read-only view of the JAX buffer; copy it (e.g. with
    ``np.array``) before modifying it in place.

    Args:
        x (Array): The JAX array to convert.

    Returns:
        np.ndarray: A NumPy array holding the same values as ``x``.
    """
    host_array = np.asarray(x)
    return host_array
def np2jax(x: np.ndarray) -> Array:
    """Return a JAX array holding the values of the NumPy array ``x``.

    This is a thin wrapper around ``jnp.asarray``.

    Args:
        x (np.ndarray): The NumPy array to convert.

    Returns:
        Array: A JAX array with the same values as ``x``.
    """
    device_array = jnp.asarray(x)
    return device_array
def pt2np(x: torch.Tensor) -> np.ndarray:
    '''
    Converts a PyTorch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to host (CPU)
    memory if needed. For a CPU tensor that does not require grad, no copy
    is made and the returned array shares memory with ``x``.

    Args:
        x (torch.Tensor): PyTorch tensor to convert.

    Returns:
        np.ndarray: Converted NumPy array.
    '''
    # Tensor.numpy() refuses tensors that require grad or live on a non-CPU
    # device. Wrapping the call in torch.no_grad() does not lift either
    # restriction, so detach and move to CPU explicitly. For the CPU,
    # no-grad case both calls are no-ops that keep the original storage.
    return x.detach().cpu().numpy()
def np2pt(x: np.ndarray) -> torch.Tensor:
    '''
    Converts a NumPy array into a PyTorch tensor.

    When possible the tensor shares memory with ``x`` (``torch.from_numpy``
    semantics), so in-place edits to one are visible in the other. Some
    arrays cannot be wrapped safely by ``torch.from_numpy``: those with a
    negative stride (e.g. ``a[::-1]``) or a read-only buffer. These are
    first copied into a fresh C-contiguous array, so the result does not
    share memory with ``x``.

    Args:
        x (np.ndarray): NumPy array to convert.

    Returns:
        torch.Tensor: Converted PyTorch tensor.
    '''
    # torch.from_numpy raises ValueError on negative strides. For
    # non-writable arrays it only warns, and writes to the tensor are
    # undefined. Copying sidesteps both. Non-ndarray inputs fall through
    # so torch.from_numpy still raises its usual TypeError for them.
    if isinstance(x, np.ndarray) and (
        not x.flags.writeable or any(stride < 0 for stride in x.strides)
    ):
        x = np.array(x, order='C')
    return torch.from_numpy(x)
def jax2pt(x: Array) -> torch.Tensor:
    '''
    Converts a JAX array into a PyTorch tensor using NumPy as intermediate.

    The values are copied into a fresh, writable NumPy buffer before being
    wrapped. The returned tensor is therefore independent of ``x``, and
    in-place edits to it cannot alter the (immutable) JAX array.

    Args:
        x (Array): JAX array to convert.

    Returns:
        torch.Tensor: Converted PyTorch tensor.
    '''
    # np.asarray on a JAX array yields a read-only buffer. torch.from_numpy
    # on that warns and produces a tensor whose writes are undefined and may
    # mutate the JAX buffer. np.array always copies into a writable array.
    return torch.from_numpy(np.array(x))
def pt2jax(x: torch.Tensor) -> Array:
    '''
    Converts a PyTorch tensor into a JAX array using NumPy as intermediate.

    The tensor is detached from the autograd graph and moved to CPU memory
    before conversion, so tensors that require grad or live on another
    device are accepted.

    Args:
        x (torch.Tensor): PyTorch tensor to convert.

    Returns:
        Array: Converted JAX array.
    '''
    # Tensor.numpy() rejects tensors that require grad or are not on the
    # CPU, so detach and move to host memory first.
    return jnp.asarray(x.detach().cpu().numpy())
Functions
def jax2np(x: jax.Array) -> numpy.ndarray
Converts a JAX array into a NumPy array.
Args
x : Array
    JAX array to convert.
Returns
np.ndarray
    Converted NumPy array.
Expand source code
def jax2np(x: Array) -> np.ndarray:
    '''
    Converts a JAX array into a NumPy array.

    Uses ``np.asarray``, so no copy is forced. The result may be a
    read-only view of the JAX buffer; copy it before writing to it.

    Args:
        x (Array): JAX array to convert.

    Returns:
        np.ndarray: Converted NumPy array.
    '''
    return np.asarray(x)
def np2jax(x: numpy.ndarray) -> jax.Array
Converts a NumPy array into a JAX array.
Args
x : np.ndarray
    NumPy array to convert.
Returns
Array
    Converted JAX array.
Expand source code
def np2jax(x: np.ndarray) -> Array:
    '''
    Converts a NumPy array into a JAX array.

    Args:
        x (np.ndarray): NumPy array to convert.

    Returns:
        Array: Converted JAX array.
    '''
    return jnp.asarray(x)
def pt2np(x: torch.Tensor) -> numpy.ndarray
Converts a PyTorch tensor into a NumPy array.
Args
x : torch.Tensor
    PyTorch tensor to convert.
Returns
np.ndarray
    Converted NumPy array.
Expand source code
def pt2np(x: torch.Tensor) -> np.ndarray:
    '''
    Converts a PyTorch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to host (CPU)
    memory if needed. For a CPU tensor that does not require grad, no copy
    is made and the returned array shares memory with ``x``.

    Args:
        x (torch.Tensor): PyTorch tensor to convert.

    Returns:
        np.ndarray: Converted NumPy array.
    '''
    # torch.no_grad() does not let Tensor.numpy() accept a tensor that
    # requires grad or lives off-CPU; detach and move to CPU explicitly.
    return x.detach().cpu().numpy()
def np2pt(x: numpy.ndarray) -> torch.Tensor
Converts a NumPy array into a PyTorch tensor.
Args
x : np.ndarray
    NumPy array to convert.
Returns
torch.Tensor
    Converted PyTorch tensor.
Expand source code
def np2pt(x: np.ndarray) -> torch.Tensor:
    '''
    Converts a NumPy array into a PyTorch tensor.

    When possible the tensor shares memory with ``x`` (``torch.from_numpy``
    semantics). Arrays with a negative stride (e.g. ``a[::-1]``) or a
    read-only buffer cannot be wrapped safely. They are copied into a fresh
    C-contiguous array first, and the result does not share memory with ``x``.

    Args:
        x (np.ndarray): NumPy array to convert.

    Returns:
        torch.Tensor: Converted PyTorch tensor.
    '''
    # torch.from_numpy raises on negative strides and only warns (writes
    # undefined) on non-writable arrays; copy in both cases. Non-ndarray
    # inputs fall through to torch.from_numpy's own TypeError.
    if isinstance(x, np.ndarray) and (
        not x.flags.writeable or any(stride < 0 for stride in x.strides)
    ):
        x = np.array(x, order='C')
    return torch.from_numpy(x)
def jax2pt(x: jax.Array) -> torch.Tensor
Converts a JAX array into a PyTorch tensor using NumPy as intermediate.
Args
x : Array
    JAX array to convert.
Returns
torch.Tensor
    Converted PyTorch tensor.
Expand source code
def jax2pt(x: Array) -> torch.Tensor:
    '''
    Converts a JAX array into a PyTorch tensor using NumPy as intermediate.

    The values are copied into a fresh, writable NumPy buffer, so the
    returned tensor is independent of ``x``. In-place edits to it cannot
    alter the (immutable) JAX array.

    Args:
        x (Array): JAX array to convert.

    Returns:
        torch.Tensor: Converted PyTorch tensor.
    '''
    # np.asarray on a JAX array yields a read-only buffer, which
    # torch.from_numpy would wrap with a warning and undefined write
    # semantics. np.array always copies into a writable array.
    return torch.from_numpy(np.array(x))
def pt2jax(x: torch.Tensor) -> jax.Array
Converts a PyTorch tensor into a JAX array using NumPy as intermediate.
Args
x : torch.Tensor
    PyTorch tensor to convert.
Returns
Array
    Converted JAX array.
Expand source code
def pt2jax(x: torch.Tensor) -> Array:
    '''
    Converts a PyTorch tensor into a JAX array using NumPy as intermediate.

    The tensor is detached from the autograd graph and moved to CPU memory
    before conversion, so tensors that require grad or live on another
    device are accepted.

    Args:
        x (torch.Tensor): PyTorch tensor to convert.

    Returns:
        Array: Converted JAX array.
    '''
    # Tensor.numpy() rejects tensors that require grad or are not on the CPU.
    return jnp.asarray(x.detach().cpu().numpy())