# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from typing import Any, Sequence
import jax
import jax.numpy as jnp
import pandas as pd
from jax import Device
from jax.tree_util import register_pytree_node_class
from energnn.graph.hyper_edge_set import HyperEdgeSet
from energnn.graph.jax.utils import jnp_to_np, np_to_jnp
from energnn.graph.utils import to_numpy
FEATURE_ARRAY = "feature_array"
FEATURE_NAMES = "feature_names"
PORT_DICT = "port_dict"
NON_FICTITIOUS = "non_fictitious"
[docs]
@register_pytree_node_class
class JaxHyperEdgeSet(dict):
"""
jax implementation of a collection of hyper-edges of the same class, optionally batched.
Internally this is just a dict storing four entries.
:param port_dict: Dictionary that maps port names to address values.
:param feature_array: Array that contains all hyper-edge features.
:param feature_names: Dictionary from feature names to index in `feature_array`.
:param non_fictitious: Binary mask filled with ones for non-fictitious objects.
"""
def __init__(
self,
*,
port_dict: dict[str, jax.Array] | None,
feature_array: jax.Array | None,
feature_names: dict[str, jax.Array] | None,
non_fictitious: jax.Array,
) -> None:
super().__init__()
self[PORT_DICT] = port_dict
self[FEATURE_ARRAY] = feature_array
self[FEATURE_NAMES] = feature_names
self[NON_FICTITIOUS] = non_fictitious
@classmethod
def from_dict(
cls,
*,
port_dict: dict[str, Any] | None = None,
feature_dict: dict[str, Any] | None = None,
) -> JaxHyperEdgeSet:
"""
Build a JaxHyperEdgeSet from raw dicts of ports and features.
Both inputs may be None, in which case the corresponding properties
are set to None and only `non_fictitious` of length zero is created.
:param port_dict: Dictionary of ports, each key corresponds to a port name and to the values are the
corresponding addresses for each object stored into an array.
:param feature_dict: Dictionary of features, each key corresponds to a feature name and to the values are the
corresponding features for each object stored into an array.
:returns: A properly structured `JaxHyperEdgeSet` instance.
:raises ValueError: If ports or features contain NaNs or if shapes mismatch.
"""
# Convert inputs to pure numpy arrays / dicts
port_dict = check_dict_or_none_jax(to_numpy(port_dict))
feature_dict = check_dict_or_none_jax(to_numpy(feature_dict))
check_valid_ports_jax(port_dict)
check_no_nan_jax(port_dict=port_dict, feature_dict=feature_dict)
# Build feature_names and feature_array
if feature_dict is not None:
feature_names = {name: idx for idx, name in enumerate(sorted(feature_dict))}
feature_array = dict2array_jax(feature_dict)
else:
feature_names, feature_array = None, None
# Build a non-fictitious mask.
shape = build_hyper_edge_set_shape_jax(port_dict=port_dict, feature_dict=feature_dict)
non_fictitious = jnp.ones(int(shape))
return cls(
port_dict=port_dict,
feature_array=feature_array,
feature_names=feature_names,
non_fictitious=non_fictitious,
)
def __str__(self) -> str:
"""
Render the JaxHyperEdgeSet as a pandas DataFrame string.
If `is_single`, uses a single-level index:
object_id
If `is_batch`, uses two-level index:
batch_id, object_id
:returns:
String representation of a `pandas.DataFrame`.
:raises ValueError:
If the internal array has unexpected dimensions.
"""
if self.is_single:
index = pd.MultiIndex.from_product([range(self.n_obj)], names=["object_id"])
elif self.is_batch:
index = pd.MultiIndex.from_product(
[range(self.n_batch), range(self.n_obj)],
names=["batch_id", "object_id"],
)
else:
raise ValueError("JaxHyperEdgeSet is neither single nor batched.")
d = {}
if self.port_names is not None:
for k, v in sorted(self.port_dict.items()):
d[("ports", k)] = v.reshape([-1])
if self.feature_names is not None:
for k, v in sorted(self.feature_dict.items()):
d[("features", k)] = v.reshape([-1])
return pd.DataFrame(d, index=index).__str__()
[docs]
def tree_flatten(self) -> tuple:
"""
Flattens a PyTree, required for JAX compatibility.
:returns: a tuple of values and keys
"""
children = self.values()
aux = self.keys()
return children, aux
[docs]
@classmethod
def tree_unflatten(cls, aux_data: Sequence[str], children: Sequence[Any]) -> JaxHyperEdgeSet:
"""
Unflattens a PyTree, required for JAX compatibility.
This method reconstructs an instance of the class from a flattened PyTree structure.
:param aux_data: Tuple of keys originally returned by tree_flatten.
:param children: Sequence of values originally returned by tree_flatten.
:return: Reconstructed instance of the class (`JaxHyperEdgeSet`).
:raises KeyError: If the expected keys are missing in the zipped dictionary.
"""
d = dict(zip(aux_data, children))
return cls(
port_dict=d[PORT_DICT],
feature_array=d[FEATURE_ARRAY],
feature_names=d[FEATURE_NAMES],
non_fictitious=d[NON_FICTITIOUS],
)
@property
def array(self) -> jax.Array:
"""
Concatenate (features, ports) along the last axis.
:returns:
Combined array of shape
- single: `(n_obj, n_feats + n_ports)`
- batch: `(batch, n_obj, n_feats + n_ports)`
"""
array = []
if self.feature_array is not None:
array.append(self.feature_array)
if self.port_array is not None:
array.append(self.port_array)
return jnp.concatenate(array, axis=-1)
@property
def is_batch(self) -> bool:
"""
True if `array` is 3-D: `(batch, n_obj, features+ports)`.
"""
return len(self.array.shape) == 3
@property
def is_single(self) -> bool:
"""
True if `array` is 2-D: `(n_obj, features+ports)`.
"""
return len(self.array.shape) == 2
@property
def n_obj(self) -> int:
"""
Number of hyper-edges (objects) per instance.
"""
if self.is_single:
return int(self.array.shape[0])
elif self.is_batch:
return int(self.array.shape[1])
else:
raise ValueError("JaxHyperEdgeSet is neither single nor batched.")
@property
def n_batch(self) -> int:
"""
Number of batches. Only valid if `is_batch` is True.
:raises ValueError: If not a batch.
"""
if self.is_batch:
return int(self.array.shape[0])
else:
raise ValueError("JaxHyperEdgeSet is not batched.")
@property
def feature_names(self) -> dict[str, jax.Array] | None:
return self[FEATURE_NAMES]
@property
def port_array(self) -> jax.Array | None:
"""
Returns the stacked array of ports, of shape `(n_obj, n_ports)` or `(batch, n_obj, n_ports)`.
"""
if self.port_dict is None:
return None
return dict2array_jax(self.port_dict)
@property
def port_names(self) -> dict[str, jax.Array] | None:
"""
Maps a port name to a column index in `port_array`.
"""
if self.port_dict is None:
return None
return {k: jnp.array(idx) for idx, k in enumerate(sorted(self.port_dict.keys()))}
@property
def port_dict(self) -> dict[str, jax.Array] | None:
return self[PORT_DICT]
@port_dict.setter
def port_dict(self, value: dict[str, jax.Array] | None) -> None:
self[PORT_DICT] = value
@property
def non_fictitious(self) -> jax.Array:
return self[NON_FICTITIOUS]
@non_fictitious.setter
def non_fictitious(self, value: jax.Array) -> None:
self[NON_FICTITIOUS] = value
@property
def feature_array(self) -> jax.Array | None:
return self[FEATURE_ARRAY]
@feature_array.setter
def feature_array(self, value: jax.Array) -> None:
self[FEATURE_ARRAY] = value
@property
def feature_dict(self) -> dict[str, jax.Array] | None:
"""
Unstack `feature_array` into a dict: feature_name --> array.
:returns: Dict of shape `(n_obj,)` or `(batch, n_obj)` per feature.
"""
if not self.feature_names:
return None
result = dict()
for k, v in self.feature_names.items():
# The last axis holds features
if self.is_batch:
result[k] = self.feature_array[..., jnp.array(v[0], int)]
else:
result[k] = self.feature_array[..., jnp.array(v, int)]
return result
@property
def feature_flat_array(self) -> jax.Array | None:
"""
Returns a flat array by concatenating all features together.
- Single mode: shape `(num_objects * num_features,)`
- Batch mode: shape `(batch_size, num_objects * num_features)`.
"""
if self.feature_names is not None:
if len(self.feature_array.shape) == 2:
return self.feature_array.reshape([-1], order="F")
elif len(self.feature_array.shape) == 3:
n_batch = self.feature_array.shape[0]
return self.feature_array.reshape([n_batch, -1], order="F")
else:
raise ValueError("Feature array should be of order 2 (single) or 3 (batch).")
else:
return None
@feature_flat_array.setter
def feature_flat_array(self, array: jax.Array) -> None:
"""
Update the feature array from a flat Fortran-ordered array.
:param array: Must match the shape of current `.feature_flat_array`.
:raises ValueError: If shapes mismatch.
"""
flat = self.feature_flat_array
if flat is None or flat.shape != array.shape:
raise ValueError("Shape mismatch for feature_flat_array setter.")
if self.feature_names is not None:
if self.is_single:
self.feature_array = array.reshape([self.n_obj, -1], order="F")
elif self.is_batch:
self.feature_array = array.reshape([self.n_batch, self.n_obj, -1], order="F")
[docs]
@classmethod
def from_numpy_hyper_edge_set(
cls, hyper_edge_set: HyperEdgeSet, device: Device | None = None, dtype: str = "float32"
) -> JaxHyperEdgeSet:
"""
Convert a classical numpy hyper-edge set to a jax.numpy format for GNN processing.
This method transforms all array-like attributes of a ``HyperEdgeSet`` object into
their JAX equivalents, allowing efficient use with JAX transformations and accelerators.
:param hyper_edge_set: A hyper-edge set object containing NumPy arrays to convert.
:param device: Optional JAX device (e.g., CPU, GPU) to place the converted arrays on.
If None, JAX uses the default device.
:param dtype: Desired floating-point precision for converted arrays (e.g., "float32", "float64").
:return: A JAX-compatible version of the hyper-edge set, ready for use in GNN pipelines.
"""
port_dict = np_to_jnp(hyper_edge_set.port_dict, device=device, dtype=dtype)
feature_array = np_to_jnp(hyper_edge_set.feature_array, device=device, dtype=dtype)
feature_names = np_to_jnp(hyper_edge_set.feature_names, device=device, dtype=dtype)
non_fictitious = np_to_jnp(hyper_edge_set.non_fictitious, device=device, dtype=dtype)
return cls(
port_dict=port_dict, feature_array=feature_array, feature_names=feature_names, non_fictitious=non_fictitious
)
[docs]
def to_numpy_hyper_edge_set(self) -> HyperEdgeSet:
"""
Convert a jax.numpy hyper-edge set for GNN processing to a classical numpy hyper-edge set.
This method transforms the internal JAX arrays of the hyper-edge set back into standard
NumPy arrays, enabling compatibility with non-JAX components.
:return: A classical ``HyperEdgeSet`` object with NumPy arrays.
"""
port_dict = jnp_to_np(self.port_dict)
feature_array = jnp_to_np(self.feature_array)
feature_names = jnp_to_np(self.feature_names)
non_fictitious = jnp_to_np(self.non_fictitious)
return HyperEdgeSet(
port_dict=port_dict, feature_array=feature_array, feature_names=feature_names, non_fictitious=non_fictitious
)
def pad(self, target_shape: jax.Array | int) -> None:
"""
Pad a *single* JaxHyperEdgeSet with a series of zeros for features and max-int for ports
so that shapes match the `target_shape`.
:param target_shape: Desired n_obj after padding; must be ≥ current n_obj.
:raises ValueError: If called on a batch or if target_shape < current n_obj.
"""
if not self.is_single:
raise ValueError("JaxHyperEdgeSet is batched, impossible to pad.")
old_n_obj = self.n_obj
if old_n_obj > target_shape:
raise ValueError("Provided target_shape is smaller than current shape, padding is impossible! ")
# Pad features
if self.feature_array is not None:
self.feature_array = jnp.pad(self.feature_array, [(0, int(target_shape) - old_n_obj), (0, 0)])
# Pad ports
if self.port_dict is not None:
for k, v in self.port_dict.items():
self.port_dict[k] = jnp.pad(v, [0, int(target_shape) - old_n_obj])
# Pad fictitious mask
if self.non_fictitious is not None:
self.non_fictitious = jnp.pad(self.non_fictitious, [0, int(target_shape) - old_n_obj])
def unpad(self, target_shape: jax.Array | int) -> None:
"""
Remove all objects beyond the index `target` in a *single* JaxHyperEdgeSet.
:param target_shape: New n_obj; must be ≤ current n_obj.
:raises ValueError: If called on a batch or if target_shape > current n_obj.
"""
if not self.is_single:
raise ValueError("JaxHyperEdgeSet is batched, impossible to unpad.")
if self.n_obj < target_shape:
raise ValueError("Provided target_shape is higher than current shape, unpadding is impossible! ")
# Unpad features
if self.feature_array is not None:
self.feature_array = self.feature_array[: int(target_shape)]
# Unpad ports
if self.port_dict is not None:
for k, v in self.port_dict.items():
self.port_dict[k] = v[: int(target_shape)]
# Unpad fictitious mask
if self.non_fictitious is not None:
self.non_fictitious = self.non_fictitious[: int(target_shape)]
def offset_addresses(self, offset: jax.Array | int) -> None:
"""Adds an offset on all addresses. Should only be used before graph concatenation.
:param offset: Scalar or array to add to each address array.
"""
self.port_dict = {k: a + jnp.array(offset) for k, a in self.port_dict.items()}
def collate_hyper_edge_sets_jax(hyper_edge_set_list: list[JaxHyperEdgeSet]) -> JaxHyperEdgeSet:
    """
    Batch several JaxHyperEdgeSet together by stacking them along a new leading axis.

    All inputs are expected to share the same port and feature names; this is verified
    against the first element. Which entries are present (not None) is also decided
    from the first element.

    :param hyper_edge_set_list: Sequence of JaxHyperEdgeSet objects to batch together. Must be non-empty.
    :return: A single batched JaxHyperEdgeSet.
    :raises IndexError: Raised if `hyper_edge_set_list` is empty.
    :raises ValueError: Raised if not all JaxHyperEdgeSet share the same keys in port_names or feature_names.
    """
    if not hyper_edge_set_list:
        raise IndexError("collate_hyper_edge_sets_jax requires at least one JaxHyperEdgeSet to collate.")

    # Every other element is compared with the first one.
    reference, *others = hyper_edge_set_list
    for other in others:
        _check_keys_consistency_jax(reference, other)

    # Feature arrays: new batch axis in front.
    feature_array = None
    if reference.feature_array is not None:
        feature_array = jnp.stack([hes.feature_array for hes in hyper_edge_set_list], axis=0)

    # Feature names: one stacked index per feature name.
    feature_names = None
    if reference.feature_names is not None:
        feature_names = {}
        for name in reference.feature_names:
            feature_names[name] = jnp.stack([hes.feature_names[name] for hes in hyper_edge_set_list])

    # Ports: one stacked address array per port name.
    port_dict = None
    if reference.port_dict is not None:
        port_dict = {}
        for name in reference.port_dict:
            port_dict[name] = jnp.stack([hes.port_dict[name] for hes in hyper_edge_set_list])

    # Non-fictitious masks.
    non_fictitious = None
    if reference.non_fictitious is not None:
        non_fictitious = jnp.stack([hes.non_fictitious for hes in hyper_edge_set_list])

    return JaxHyperEdgeSet(
        port_dict=port_dict,
        feature_array=feature_array,
        feature_names=feature_names,
        non_fictitious=non_fictitious,
    )
def separate_hyper_edge_sets_jax(hyper_edge_set_batch: JaxHyperEdgeSet) -> list[JaxHyperEdgeSet]:
    """
    Split a batched JaxHyperEdgeSet back into one JaxHyperEdgeSet per batch element.

    The input must report ``is_batch`` as True, e.g. because it was produced by
    :py:func:`collate_hyper_edge_sets_jax`. Entries that are None in the batch are
    None in every returned element.

    :param hyper_edge_set_batch: The batched JaxHyperEdgeSet to unstack.
    :return: List of JaxHyperEdgeSet instances, each corresponding to one batch element.
    :raises ValueError: If `hyper_edge_set_batch.is_batch` is False.
    """
    if not hyper_edge_set_batch.is_batch:
        raise ValueError("Input is not a batch, impossible to separate.")

    def _placeholders() -> list:
        # One None per batch element, for entries missing from the batch.
        return [None] * hyper_edge_set_batch.n_batch

    def _split_mapping(mapping: dict) -> list:
        # Unstack every value, then regroup the slices as one dict per batch element.
        slices_by_key = {key: jnp.unstack(mapping[key]) for key in mapping}
        return [dict(zip(slices_by_key, row)) for row in zip(*slices_by_key.values())]

    batch_features = hyper_edge_set_batch.feature_array
    per_item_features = _placeholders() if batch_features is None else jnp.unstack(batch_features, axis=0)

    batch_names = hyper_edge_set_batch.feature_names
    per_item_names = _placeholders() if batch_names is None else _split_mapping(batch_names)

    batch_ports = hyper_edge_set_batch.port_dict
    per_item_ports = _placeholders() if batch_ports is None else _split_mapping(batch_ports)

    batch_mask = hyper_edge_set_batch.non_fictitious
    per_item_mask = _placeholders() if batch_mask is None else jnp.unstack(batch_mask, axis=0)

    return [
        JaxHyperEdgeSet(port_dict=ports, feature_array=features, feature_names=names, non_fictitious=mask)
        for features, names, ports, mask in zip(per_item_features, per_item_names, per_item_ports, per_item_mask)
    ]
def concatenate_hyper_edge_sets_jax(hyper_edge_set_list: list[JaxHyperEdgeSet]) -> JaxHyperEdgeSet:
    """
    Concatenate several single JaxHyperEdgeSet into one single JaxHyperEdgeSet.

    Unlike :py:func:`collate_hyper_edge_sets_jax`, this does *not* create a batch dimension,
    but simply stacks objects end-to-end along the first axis.

    Port names, feature names and the presence of ports / features are taken from the
    first element; a missing (None) port dictionary or feature array stays None in the result.

    :param hyper_edge_set_list: Non-empty list of single (non-batched) JaxHyperEdgeSet.
    :returns: One JaxHyperEdgeSet with n_obj = sum of all inputs' n_obj.
    :raises IndexError: If `hyper_edge_set_list` is empty.
    """
    if not hyper_edge_set_list:
        raise IndexError("concatenate_hyper_edge_sets_jax requires at least one JaxHyperEdgeSet to concatenate.")
    first = hyper_edge_set_list[0]

    # Ports are optional: only concatenate them when the first element has some.
    if first.port_dict is not None:
        port_dict = {
            k: jnp.concatenate([hes.port_dict[k] for hes in hyper_edge_set_list]) for k in first.port_dict
        }
    else:
        port_dict = None

    # Features are optional as well.
    if first.feature_array is not None:
        feature_array = jnp.concatenate([hes.feature_array for hes in hyper_edge_set_list], axis=0)
    else:
        feature_array = None

    # Feature names are reused from the first element.
    feature_names = first.feature_names
    non_fictitious = jnp.concatenate([hes.non_fictitious for hes in hyper_edge_set_list])
    return JaxHyperEdgeSet(
        port_dict=port_dict, feature_array=feature_array, feature_names=feature_names, non_fictitious=non_fictitious
    )
def check_dict_shape_jax(*, d: dict[str, jax.Array] | None, n_objects: int | None) -> int | None:
"""
Ensure all arrays in a dictionary have the same size on their last axis.
If `n_objects` is not provided, it is inferred from the first array’s last dimension.
Otherwise, every array’s last dimension must match the given `n_objects`.
:param d: Mapping from feature/port name to `jax.numpy` array
where each array’s last axis is object-indexed.
:param n_objects: Optional expected size of the last axis; if None, will be inferred.
:return: The validated or inferred `n_objects`.
:raises ValueError: If any array’s last dimension does not match `n_objects`.
"""
if d is not None:
if n_objects is None:
item: jnp.ndarray = next(iter(d.values()))
n_objects = item.shape[-1]
for name, arr in d.items():
if arr.shape[-1] != n_objects:
raise ValueError(f"Array for key '{name}' has last dimension {arr.shape[-1]}, expected {n_objects}.")
return n_objects
def build_hyper_edge_set_shape_jax(
*,
port_dict: dict[str, jax.Array] | None,
feature_dict: dict[str, jax.Array] | None,
) -> jax.Array:
"""
Builds a jax.numpy array representing the number of hyper-edges.
Validate that `port_dict` and `feature_dict` have consistent sizes
on their last dimensions and return a scalar jax array containing that count.
:param port_dict: Mapping from port names to jax.numpy arrays, or None.
:param feature_dict: Mapping of feature names to jax.numpy arrays, or None.
:return: A scalar jax.numpy array of dtype float32 with the number of objects.
:raises ValueError: If both inputs are None, or if their shapes conflict.
"""
if port_dict is None and feature_dict is None:
raise ValueError("At least one of port_dict or feature_dict must be provided.")
n_objects = check_dict_shape_jax(d=port_dict, n_objects=None)
n_objects = check_dict_shape_jax(d=feature_dict, n_objects=n_objects)
return jnp.array(n_objects, dtype=jnp.dtype("float32"))
def dict2array_jax(features_dict: dict[str, jax.Array] | None) -> jax.Array | None:
"""
Stack a dictionary of jax arrays into a single jax array along the last axis.
The jax arrays are stacked in alphabetical order of their dictionary keys.
:param features_dict: Mapping from a feature name to a `jax.numpy` array, or None.
:return: A stacked jax array with an added last dimension for features, or None.
"""
if features_dict is None:
return None
return jnp.stack([features_dict[k] for k in sorted(features_dict)], axis=-1)
def check_dict_or_none_jax(_input: dict | jnp.ndarray | None) -> dict | None:
"""
Validate that the input is either a dict or None.
:param _input: Object to validate
:return: the input if it was a dict or None
:raises ValueError: if `_input` is neither dict nor None
"""
if isinstance(_input, dict):
return _input
if _input is None:
return None
raise ValueError(f"Expected dict or None, got {type(_input)}")
def check_no_nan_jax(
*,
port_dict: dict[str, jax.Array] | None,
feature_dict: dict[str, jax.Array] | None,
) -> None:
"""
Ensure there are no NaN values in port or feature arrays.
:param port_dict: Mapping from port names to jax arrays, or None.
:param feature_dict: Mapping of feature names to jax arrays, or None.
:raises ValueError: If any jax array contains NaN.
"""
for name, arr in (port_dict or {}).items():
if jnp.any(jnp.isnan(arr)):
raise ValueError(f"NaN detected in port array for key '{name}'.")
for name, arr in (feature_dict or {}).items():
if jnp.any(jnp.isnan(arr)):
raise ValueError(f"NaN detected in feature array for key '{name}'.")
def check_valid_ports_jax(port_dict: dict[str, jax.Array] | None) -> None:
"""
Ensure that ports map only to integer-valued addresses.
:param port_dict: Mapping from port names to jax arrays, or None.
:raises ValueError: If any port array has entries that are not integer.
"""
for name, arr in (port_dict or {}).items():
if not jnp.allclose(arr, jnp.int32(arr)):
raise ValueError(f"Non-integer values detected in port array for key '{name}'.")
def _check_keys_consistency_jax(hes_1: JaxHyperEdgeSet, hes_2: JaxHyperEdgeSet):
    """
    Check that two hyper-edge sets expose the same port names and feature names.

    Presence (None vs. not None) is compared first for ports and then for features;
    the key sets are compared afterwards, only when the first set's names are non-empty.

    :param hes_1: Reference hyper-edge set.
    :param hes_2: Hyper-edge set compared with the reference.
    :raises ValueError: If names are present in only one of the two sets, or if their keys differ.
    """
    ports_1, ports_2 = hes_1.port_names, hes_2.port_names
    if (ports_1 is None) is not (ports_2 is None):
        raise ValueError("Mismatch in presence of port_names among hyper-edge sets.")

    features_1, features_2 = hes_1.feature_names, hes_2.feature_names
    if (features_1 is None) is not (features_2 is None):
        raise ValueError("Mismatch in presence of feature_names among hyper-edge sets.")

    if ports_1:
        if ports_1.keys() != ports_2.keys():
            raise ValueError("Inconsistent port_names keys among hyper-edge sets.")
    if features_1:
        if features_1.keys() != features_2.keys():
            raise ValueError("Inconsistent feature_names keys among hyper-edge sets.")