np_to_jnp

np_to_jnp(x, device=None, dtype='float32')

Convert a NumPy array, or a dictionary of NumPy arrays, to JAX arrays.

This function accepts either a single NumPy array or a dictionary mapping string keys to NumPy arrays. Each array is converted to a JAX array of the requested data type and, if a device is given, placed on that device.
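The conversion described above can be sketched as follows. This is a hypothetical re-implementation named `np_to_jnp_sketch`, written under the assumption that the cast is done with `jnp.asarray` and the device placement with `jax.device_put`; the actual source may be implemented differently.

```python
from __future__ import annotations

from typing import Optional, Union

import jax
import jax.numpy as jnp
import numpy as np


def np_to_jnp_sketch(
    x: Union[np.ndarray, dict[str, np.ndarray], None],
    device: Optional[jax.Device] = None,
    dtype: str = "float32",
) -> Union[jax.Array, dict[str, jax.Array], None]:
    """Hypothetical sketch of np_to_jnp; not the library's own code."""
    if x is None:
        # Documented behaviour: None passes straight through.
        return None

    def convert(arr: np.ndarray) -> jax.Array:
        # Build a JAX array cast to the requested dtype.
        out = jnp.asarray(arr, dtype=dtype)
        # Move it only when a device was requested; otherwise it stays
        # on JAX's default device.
        return jax.device_put(out, device) if device is not None else out

    if isinstance(x, dict):
        # Convert every value and keep the original keys.
        return {key: convert(value) for key, value in x.items()}
    return convert(x)
```

Branching on `dict` rather than mapping over an arbitrary pytree keeps the sketch aligned with the documented input types: a single array, a `dict[str, ndarray]`, or `None`.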

Parameters:
  • x (ndarray | dict[str, ndarray] | None) – NumPy array or dict of NumPy arrays to convert. If None, the function returns None.

  • device (Device | None) – JAX device to place the arrays on. If None, the default JAX device is used.

  • dtype (str) – Data type for the JAX arrays (e.g., 'float32'). Defaults to 'float32'.
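The `device` parameter expects a JAX `Device` object rather than a string. The sketch below shows one way to obtain such an object and to cast and place an array with the core JAX calls `jax.devices` and `jax.device_put`. It does not call `np_to_jnp` itself; that the function relies on these same calls internally is an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np

# jax.devices() returns Device objects for the default backend; the first
# entry is where JAX places new arrays when no device is specified.
default_device = jax.devices()[0]

# Passing a backend name restricts the list; the CPU backend is always present.
cpu_device = jax.devices("cpu")[0]

host = np.linspace(0.0, 1.0, 4)  # NumPy creates float64 by default

# Cast to the requested dtype, then commit the result to the chosen device.
arr = jax.device_put(jnp.asarray(host, dtype="float32"), cpu_device)

print(arr.dtype)                   # float32
print(cpu_device in arr.devices())  # True
```

Note that JAX disables 64-bit types by default, so requesting 'float64' generally yields float32 unless the `jax_enable_x64` flag is set.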

Returns:

JAX array or dict of JAX arrays matching the structure of the input, or None if the input is None.

Return type:

Array | dict[str, Array] | None