Source code for energnn.graph.jax.hyper_edge_set

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from __future__ import annotations

from typing import Any, Sequence

import jax
from jax import Device
from jax.tree_util import register_pytree_node_class

from energnn.graph.hyper_edge_set import HyperEdgeSet
from energnn.graph.jax.utils import jnp_to_np, np_to_jnp

# Keys under which ``JaxHyperEdgeSet`` stores its four entries. They are listed
# in the order the constructor inserts them, which is also the order of the
# pytree children produced by ``JaxHyperEdgeSet.tree_flatten``.
PORT_DICT = "port_dict"  # port name -> address values
FEATURE_ARRAY = "feature_array"  # array holding all hyper-edge features
FEATURE_NAMES = "feature_names"  # feature name -> index in the feature array
NON_FICTITIOUS = "non_fictitious"  # mask of ones for non-fictitious objects


@register_pytree_node_class
class JaxHyperEdgeSet(dict):
    """
    jax implementation of a collection of hyper-edges of the same class, optionally batched.

    Internally this is just a dict storing four entries. It is registered as a JAX pytree
    node, so instances can be passed through JAX transformations.

    :param port_dict: Dictionary that maps port names to address values.
    :param feature_array: Array that contains all hyper-edge features.
    :param feature_names: Dictionary from feature names to index in `feature_array`.
    :param non_fictitious: Binary mask filled with ones for non-fictitious objects.
    """

    def __init__(
        self,
        *,
        port_dict: dict[str, jax.Array] | None,
        feature_array: jax.Array | None,
        feature_names: dict[str, jax.Array] | None,
        non_fictitious: jax.Array,
    ):
        super().__init__()
        # Insertion order determines the order of the children returned by tree_flatten.
        self[PORT_DICT] = port_dict
        self[FEATURE_ARRAY] = feature_array
        self[FEATURE_NAMES] = feature_names
        self[NON_FICTITIOUS] = non_fictitious

    def tree_flatten(self) -> tuple[tuple[Any, ...], tuple[str, ...]]:
        """
        Flattens a PyTree, required for JAX compatibility.

        Children and keys are materialized as tuples instead of being returned as live
        ``dict_values`` / ``dict_keys`` views:

        - JAX requires the auxiliary data of a custom pytree node to be hashable and
          comparable, and a ``dict_keys`` view is not hashable.
        - A view would also keep a reference to this instance and would change if the
          dict were mutated after flattening.

        :returns: a tuple ``(children, aux_data)`` where ``children`` holds the stored values
            and ``aux_data`` the matching keys, both in insertion order.
        """
        children = tuple(self.values())
        aux_data = tuple(self.keys())
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data: Sequence[str], children: Sequence[Any]) -> JaxHyperEdgeSet:
        """
        Unflattens a PyTree, required for JAX compatibility.

        This method reconstructs an instance of the class from a flattened PyTree structure.

        :param aux_data: Tuple of keys originally returned by tree_flatten.
        :param children: Sequence of values originally returned by tree_flatten.
        :return: Reconstructed instance of the class (`JaxHyperEdgeSet`).
        :raises KeyError: If the expected keys are missing in the zipped dictionary.
        """
        # Rebuild by key name rather than position, so the result does not depend on
        # the order in which the keys were stored.
        d = dict(zip(aux_data, children))
        return cls(
            port_dict=d[PORT_DICT],
            feature_array=d[FEATURE_ARRAY],
            feature_names=d[FEATURE_NAMES],
            non_fictitious=d[NON_FICTITIOUS],
        )

    @property
    def feature_names(self) -> dict[str, jax.Array] | None:
        """Dictionary from feature names to index in `feature_array`, or None."""
        return self[FEATURE_NAMES]

    @property
    def port_dict(self) -> dict[str, jax.Array] | None:
        """Dictionary that maps port names to address values, or None."""
        return self[PORT_DICT]

    @property
    def non_fictitious(self) -> jax.Array:
        """Binary mask filled with ones for non-fictitious objects."""
        return self[NON_FICTITIOUS]

    @property
    def feature_array(self) -> jax.Array | None:
        """Array that contains all hyper-edge features, or None."""
        return self[FEATURE_ARRAY]

    @feature_array.setter
    def feature_array(self, value: jax.Array) -> None:
        self[FEATURE_ARRAY] = value

    @property
    def feature_flat_array(self) -> jax.Array | None:
        """
        Returns a flat array by concatenating all features together.

        - Single mode: shape `(num_objects * num_features,)`
        - Batch mode: shape `(batch_size, num_objects * num_features)`.

        The reshape uses Fortran (column-major) order. Within each flattened row, the
        values of the first feature for every object come first, then the second feature,
        and so on. The batch axis is preserved.

        :return: The flattened features, or None when `feature_names` is None.
        :raises ValueError: If `feature_array` is neither 2-D (single) nor 3-D (batch).
        """
        if self.feature_names is None:
            return None
        ndim = len(self.feature_array.shape)
        if ndim == 2:
            return self.feature_array.reshape([-1], order="F")
        if ndim == 3:
            n_batch = self.feature_array.shape[0]
            return self.feature_array.reshape([n_batch, -1], order="F")
        raise ValueError("Feature array should be of order 2 (single) or 3 (batch).")

    @classmethod
    def from_numpy_hyper_edge_set(
        cls, hyper_edge_set: HyperEdgeSet, device: Device | None = None, dtype: str = "float32"
    ) -> JaxHyperEdgeSet:
        """
        Convert a classical numpy hyper-edge set to a jax.numpy format for GNN processing.

        This method converts each of the four attributes of a ``HyperEdgeSet`` object with
        ``np_to_jnp``, so the result can be used with JAX transformations and accelerators.

        :param hyper_edge_set: A hyper-edge set object containing NumPy arrays to convert.
        :param device: Optional JAX device (e.g., CPU, GPU) to place the converted arrays on.
            If None, JAX uses the default device.
        :param dtype: Desired floating-point precision for converted arrays (e.g., "float32", "float64").
            The same `dtype` is passed to ``np_to_jnp`` for all four attributes.
        :return: A JAX-compatible version of the hyper-edge set, ready for use in GNN pipelines.
        """
        port_dict = np_to_jnp(hyper_edge_set.port_dict, device=device, dtype=dtype)
        feature_array = np_to_jnp(hyper_edge_set.feature_array, device=device, dtype=dtype)
        feature_names = np_to_jnp(hyper_edge_set.feature_names, device=device, dtype=dtype)
        non_fictitious = np_to_jnp(hyper_edge_set.non_fictitious, device=device, dtype=dtype)
        return cls(
            port_dict=port_dict,
            feature_array=feature_array,
            feature_names=feature_names,
            non_fictitious=non_fictitious,
        )

    def to_numpy_hyper_edge_set(self) -> HyperEdgeSet:
        """
        Convert a jax.numpy hyper-edge set for GNN processing to a classical numpy hyper-edge set.

        This method converts each of the four stored entries back with ``jnp_to_np``,
        enabling compatibility with non-JAX components.

        :return: A classical ``HyperEdgeSet`` object with NumPy arrays.
        """
        port_dict = jnp_to_np(self.port_dict)
        feature_array = jnp_to_np(self.feature_array)
        feature_names = jnp_to_np(self.feature_names)
        non_fictitious = jnp_to_np(self.non_fictitious)
        return HyperEdgeSet(
            port_dict=port_dict,
            feature_array=feature_array,
            feature_names=feature_names,
            non_fictitious=non_fictitious,
        )