Source code for energnn.graph.utils

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from typing import Any

import jax
import numpy as np


[docs] def to_numpy(a: dict | np.ndarray | jax.Array | tuple | None) -> dict | np.ndarray | None: """ Converts a NumPy array, JAX array, or tuple of values into a NumPy array (dtype float32), or converts the values in a dictionary accordingly. - If `a` is None, returns None. - If `a` is a np.ndarray, jax.Array, jnp.ndarray, or tuple, it is converted to a np.ndarray (float32). - If `a` is a dict with some values being arrays or tuples, only those values are converted; others remain unchanged. - In all other cases, a TypeError is raised. :param a: A np.ndarray, jax.Array, tuple, dict, or None. :returns: Either None, a np.ndarray, or a dict with the same keys and converted np.ndarray values. :raises TypeError: If `a` is not of an expected or supported type. """ if a is None: return None def _to_np(x: Any) -> Any: # On traite np.ndarray, jax.Array et tuple if isinstance(x, (np.ndarray, jax.Array, np.ndarray, tuple)): return np.array(x, dtype=np.dtype("float32")) else: return x if isinstance(a, dict): output: dict[Any, np.array] = {} for key, value in a.items(): output[key] = _to_np(value) # seules les values “ArrayLike” seront converties return output # Cas array-like, tuple et object if isinstance(a, (np.ndarray, jax.Array, np.ndarray, tuple, object)): return _to_np(a) raise TypeError(f"Type {type(a)} non pris en charge par to_numpy")