to_numpy

to_numpy(a)

Converts a NumPy array, JAX array, or tuple of values into a NumPy array with dtype float32. If given a dictionary, converts its array and tuple values in the same way.

  • If a is None, returns None.

  • If a is a np.ndarray, jax.Array, jnp.ndarray, or tuple, it is converted to a np.ndarray (float32).

  • If a is a dict, only its values that are arrays or tuples are converted; all other values are returned unchanged.

  • In all other cases, a TypeError is raised.
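The dispatch rules above can be sketched as follows. This is a minimal sketch that assumes only the documented behaviour and is not the library's actual source. The names `to_numpy_sketch`, `_ARRAY_TYPES`, and `_CONVERTIBLE` are hypothetical, and JAX is treated as an optional dependency so the sketch also runs without it.

```python
import numpy as np

# Hypothetical helper names; this mirrors the documented rules, not the real source.
try:
    import jax  # optional: only needed to recognise JAX arrays
    _ARRAY_TYPES = (np.ndarray, jax.Array)  # jnp.ndarray is an alias of jax.Array
except (ImportError, AttributeError):  # JAX missing or too old to expose jax.Array
    _ARRAY_TYPES = (np.ndarray,)

_CONVERTIBLE = _ARRAY_TYPES + (tuple,)


def to_numpy_sketch(a):
    """Hypothetical re-implementation of the documented to_numpy behaviour."""
    if a is None:
        return None
    if isinstance(a, _CONVERTIBLE):
        # np.asarray accepts JAX arrays via the array protocol and casts to float32
        return np.asarray(a, dtype=np.float32)
    if isinstance(a, dict):
        # Convert only array/tuple values; every other value is passed through
        return {
            k: np.asarray(v, dtype=np.float32) if isinstance(v, _CONVERTIBLE) else v
            for k, v in a.items()
        }
    raise TypeError(f"Unsupported type for to_numpy: {type(a).__name__}")


# Usage: dict values are converted selectively, and lists are rejected.
batch = {"obs": (1, 2, 3), "step": 7, "tag": "train"}
converted = to_numpy_sketch(batch)
print(converted["obs"].dtype)  # float32
print(converted["step"])       # 7 (an int is neither an array nor a tuple)

try:
    to_numpy_sketch([1, 2, 3])  # lists are not in the supported set
except TypeError as err:
    print(err)  # Unsupported type for to_numpy: list
```

Note that, under these rules, a plain list is rejected even though NumPy could convert it; the sketch follows the documented "all other cases" rule rather than NumPy's more permissive behaviour.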

Parameters:

a (dict | ndarray | Array | tuple | None) – A np.ndarray, jax.Array, tuple, dict, or None.

Returns:

Either None, a float32 np.ndarray, or a dict with the same keys in which array and tuple values are converted to float32 np.ndarray and all other values are unchanged.

Raises:

TypeError – If a is not None, a np.ndarray, a jax.Array, a tuple, or a dict.

Return type:

dict | ndarray | None