# Source code for energnn.graph.jax.shape

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from __future__ import annotations

import jax
from jax import Device
from jax.tree_util import register_pytree_node_class

from energnn.graph.jax.utils import jnp_to_np, np_to_jnp
from energnn.graph.shape import GraphShape

# Dictionary keys under which JaxGraphShape stores its two entries.
HYPER_EDGE_SETS, ADDRESSES = "hyper_edge_sets", "addresses"


@register_pytree_node_class
class JaxGraphShape(dict):
    """
    PyTree container for storing the number of objects in each class, and addresses in the graph.

    This class inherits from `dict` and stores two keys:

    :param hyper_edge_sets: Dictionary that contains the number of objects for each class.
    :param addresses: Number of addresses in the graph.

    The PyTree methods ``tree_flatten`` and ``tree_unflatten`` make this object compatible
    with JAX transformations (jit, vmap, etc.).
    """

    def __init__(self, *, hyper_edge_sets: dict[str, jax.Array], addresses: jax.Array):
        super().__init__()
        self[HYPER_EDGE_SETS] = hyper_edge_sets
        self[ADDRESSES] = addresses

    def tree_flatten(self):
        """
        Flatten the JaxGraphShape for JAX PyTree compatibility.

        Keys and values are materialized as tuples instead of being returned as live
        ``dict_keys`` / ``dict_values`` views. JAX keeps the auxiliary data inside the
        treedef, where it must be hashable (``dict_keys`` is not). It also must not
        change if this instance is mutated after flattening.

        :returns: Flat children (the stored values) and auxiliary data (the keys, in the
            same order as the children).
        """
        keys = tuple(self.keys())
        # Build the children from the key snapshot so both tuples share one ordering.
        children = tuple(self[key] for key in keys)
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> JaxGraphShape:
        """
        Reconstruct a JaxGraphShape from flattened data, required for JAX compatibility.

        :param aux_data: Sequence of keys matching the order of the children.
        :param children: Sequence of values, in the order given by ``aux_data``.
        :return: A reconstructed JaxGraphShape instance.
        """
        d = dict(zip(aux_data, children))
        return cls(hyper_edge_sets=d[HYPER_EDGE_SETS], addresses=d[ADDRESSES])

    @property
    def hyper_edge_sets(self) -> dict[str, jax.Array]:
        """Dictionary that contains the number of objects for each class."""
        return self[HYPER_EDGE_SETS]

    @property
    def addresses(self) -> jax.Array:
        """Number of addresses in the graph."""
        return self[ADDRESSES]

    @classmethod
    def from_numpy_shape(cls, shape: GraphShape, device: Device | None = None, dtype: str = "float32") -> JaxGraphShape:
        """
        Convert a classical numpy shape to a jax.numpy format for GNN processing.

        This method transforms all array-like attributes of a ``GraphShape`` object into their JAX equivalents,
        allowing efficient use with JAX transformations and accelerators.

        :param shape: A shape object containing NumPy arrays to convert.
        :param device: Optional JAX device (e.g., CPU, GPU) to place the converted arrays on.
            If None, JAX uses the default device.
        :param dtype: Desired floating-point precision for converted arrays (e.g., "float32", "float64").
            NOTE(review): this is forwarded to ``np_to_jnp`` for the counts as well, so they are
            presumably stored as floats by default — confirm this is intended.
        :return: A JAX-compatible version of the shape, ready for use in GNN pipelines.
        """
        hyper_edge_sets = np_to_jnp(shape.hyper_edge_sets, device=device, dtype=dtype)
        addresses = np_to_jnp(shape.addresses, device=device, dtype=dtype)
        return cls(hyper_edge_sets=hyper_edge_sets, addresses=addresses)

    def to_numpy_shape(self) -> GraphShape:
        """
        Convert a jax.numpy shape for GNN processing to a classical numpy shape.

        This method transforms the internal JAX arrays of the shape back into standard NumPy arrays,
        enabling compatibility with non-JAX components.

        :return: A classical ``GraphShape`` object with NumPy arrays.
        """
        hyper_edge_sets = jnp_to_np(self.hyper_edge_sets)
        addresses = jnp_to_np(self.addresses)
        return GraphShape(hyper_edge_sets=hyper_edge_sets, addresses=addresses)