jnp_to_np

jnp_to_np(x)

Convert JAX arrays or mappings of JAX arrays back to NumPy arrays.

This function accepts either a single JAX array or a dictionary mapping string keys to JAX arrays and converts each array to a NumPy array. For a dictionary, the keys are preserved and each value is converted.

Parameters:

x (Array | dict[str, Array] | None) – JAX array or dict of JAX arrays to convert. If None, returns None.

Returns:

NumPy array or dict of NumPy arrays matching the input structure, or None if the input is None.

Return type:

ndarray | dict[str, ndarray] | None
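The documented contract can be illustrated with the following minimal sketch. It is not the library's source. `jnp_to_np_sketch` is a hypothetical name, and the use of `numpy.asarray` is an assumption about how the conversion might be done. JAX arrays implement the `__array__` protocol, so `np.asarray` copies them to host memory as NumPy arrays. The sketch depends only on NumPy, so plain NumPy arrays or lists also work as inputs.

```python
# Hypothetical sketch of the jnp_to_np contract; not the library's actual implementation.
from typing import Any

import numpy as np


def jnp_to_np_sketch(x: Any) -> Any:
    """Illustrative re-implementation: array -> ndarray, dict -> dict of ndarrays, None -> None."""
    # None passes straight through, as documented.
    if x is None:
        return None
    # Dict of arrays: convert each value and keep the same keys.
    if isinstance(x, dict):
        return {key: np.asarray(value) for key, value in x.items()}
    # Single array: np.asarray uses the JAX array's __array__ protocol,
    # which produces a NumPy array on the host.
    return np.asarray(x)
```

With JAX installed, `jnp_to_np_sketch({"w": jnp.ones(3)})` returns a dict whose `"w"` entry is a `numpy.ndarray`. The output has the same structure as the input, as described under Returns.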