Module lib.param_utils.check_params_equal

Expand source code
from typing import Any

from jax import Array
import jax.numpy as jnp
import numpy as np

def check_params_equal(t1: Any, t2: Any) -> bool:
    '''
    Compare two objects for equality, descending into namedtuples.

    Objects whose classes differ are reported as unequal straight away.
    For two objects of the same class the comparison strategy is:

    - NumPy arrays are compared with `np.array_equal()`.
    - JAX arrays are compared with `jnp.array_equal()`, converted to `bool`.
    - Namedtuples (tuples exposing `_fields`) are compared field by field,
      calling this function recursively and stopping at the first mismatch.
    - Anything else falls back to the `==` operator, whose result is
      returned as-is.
    '''
    # Objects of different classes are never considered equal.
    if t1.__class__ != t2.__class__:
        return False

    # A tuple carrying `_fields` is treated as a namedtuple.
    is_namedtuple = isinstance(t1, tuple) and hasattr(t1, '_fields')

    if isinstance(t1, np.ndarray):
        result = np.array_equal(t1, t2)
    elif isinstance(t1, Array):
        result = bool(jnp.array_equal(t1, t2))
    elif is_namedtuple:
        # Walk corresponding fields in order; bail out on the first difference.
        result = True
        for left, right in zip(t1, t2):
            if not check_params_equal(left, right):
                result = False
                break
    else:
        result = t1 == t2
    return result

Functions

def check_params_equal(t1: Any, t2: Any) ‑> bool

Recursively checks the equality of two objects.

If both objects are NumPy arrays, np.array_equal() is used for comparison. If both objects are JAX arrays, jnp.array_equal() is used for comparison. If both objects are namedtuples, the function is called recursively on each corresponding field. Otherwise, the standard equality operator == is used.

Expand source code
def check_params_equal(t1: Any, t2: Any) -> bool:
    '''
    Compare two objects for equality, descending into namedtuples.

    Objects whose classes differ are reported as unequal straight away.
    For two objects of the same class the comparison strategy is:

    - NumPy arrays are compared with `np.array_equal()`.
    - JAX arrays are compared with `jnp.array_equal()`, converted to `bool`.
    - Namedtuples (tuples exposing `_fields`) are compared field by field,
      calling this function recursively and stopping at the first mismatch.
    - Anything else falls back to the `==` operator, whose result is
      returned as-is.
    '''
    # Objects of different classes are never considered equal.
    if t1.__class__ != t2.__class__:
        return False

    # A tuple carrying `_fields` is treated as a namedtuple.
    is_namedtuple = isinstance(t1, tuple) and hasattr(t1, '_fields')

    if isinstance(t1, np.ndarray):
        result = np.array_equal(t1, t2)
    elif isinstance(t1, Array):
        result = bool(jnp.array_equal(t1, t2))
    elif is_namedtuple:
        # Walk corresponding fields in order; bail out on the first difference.
        result = True
        for left, right in zip(t1, t2):
            if not check_params_equal(left, right):
                result = False
                break
    else:
        result = t1 == t2
    return result