Source code for energnn.graph.jax.graph

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from __future__ import annotations

import pickle as pkl

import jax
import jax.numpy as jnp
from jax import Device
from jax.tree_util import register_pytree_node_class

from energnn.graph.graph import Graph
from energnn.graph.jax.hyper_edge_set import (
    JaxHyperEdgeSet,
    collate_hyper_edge_sets_jax,
    concatenate_hyper_edge_sets_jax,
    separate_hyper_edge_sets_jax,
)
from energnn.graph.jax.shape import JaxGraphShape, collate_shapes_jax, separate_shapes_jax, sum_shapes_jax
from energnn.graph.jax.utils import jnp_to_np, np_to_jnp

HYPER_EDGE_SETS = "hyper_edge_sets"
TRUE_SHAPE = "true_shape"
CURRENT_SHAPE = "current_shape"
NON_FICTITIOUS_ADDRESSES = "non_fictitious_addresses"


@register_pytree_node_class
class JaxGraph(dict):
    """
    Jax implementation of Hyper Heterogeneous Multi Graph (H2MG).

    Stores hyper-edge sets, shapes, and address masks for single or batched graphs.
    The object is a ``dict`` subclass registered as a JAX pytree node: its four entries
    (hyper-edge sets, true shape, current shape, address mask) are the pytree children.

    :param hyper_edge_sets: Dictionary of hyper-edge sets contained in the graph.
    :param true_shape: True shape of the graph, not altered by padding.
    :param current_shape: Current shape of the graph, consistent with padding.
    :param non_fictitious_addresses: Mask filled with ones for real addresses, and zeros otherwise.
    """

    def __init__(
        self,
        *,
        hyper_edge_sets: dict[str, JaxHyperEdgeSet],
        true_shape: JaxGraphShape,
        current_shape: JaxGraphShape,
        non_fictitious_addresses: jax.Array,
    ) -> None:
        super().__init__()
        self[HYPER_EDGE_SETS] = hyper_edge_sets
        self[TRUE_SHAPE] = true_shape
        self[CURRENT_SHAPE] = current_shape
        self[NON_FICTITIOUS_ADDRESSES] = non_fictitious_addresses

    @classmethod
    def from_dict(cls, *, hyper_edge_set_dict: dict[str, JaxHyperEdgeSet], n_addresses: jax.Array) -> JaxGraph:
        """
        Builds a graph from a dictionary of :class:`energnn.graph.JaxHyperEdgeSet` and a number of addresses.

        The resulting graph is not padded: its current shape is its true shape and every
        address is marked as non-fictitious.

        :param hyper_edge_set_dict: Dictionary of hyper-edge sets contained in the graph.
        :param n_addresses: Number of unique addresses that appear in all the hyper-edge sets.
        :return: Graph that contains both the hyper-edge sets and the registry.
        :raises TypeError: If ``hyper_edge_set_dict`` is not a dictionary of hyper-edge sets.
        :raises AssertionError: If an address is not smaller than ``n_addresses``.
        """
        non_fictitious_addresses = jnp.ones(shape=[n_addresses])
        check_hyper_edge_set_dict_type_jax(hyper_edge_set_dict)
        check_valid_addresses_jax(hyper_edge_set_dict, n_addresses)
        true_shape = JaxGraphShape.from_dict(hyper_edge_set_dict=hyper_edge_set_dict, non_fictitious=non_fictitious_addresses)
        return cls(
            hyper_edge_sets=hyper_edge_set_dict,
            true_shape=true_shape,
            current_shape=true_shape,
            non_fictitious_addresses=non_fictitious_addresses,
        )

    @property
    def true_shape(self) -> JaxGraphShape:
        """
        True shape of the graph with the real number of objects for each hyper-edge set class
        as well as the size of the registry stored in a GraphShape object.

        There is no setter for this property.

        :return: A graph shape of true sizes.
        """
        return self[TRUE_SHAPE]

    @property
    def current_shape(self) -> JaxGraphShape:
        """
        The current shape of the graph taking into account fake padding objects.

        :return: A graph shape of current sizes.
        """
        return self[CURRENT_SHAPE]

    @current_shape.setter
    def current_shape(self, value: JaxGraphShape) -> None:
        """
        Sets the current shape of the graph taking into account fake padding objects.

        :param value: A new graph shape.
        """
        self[CURRENT_SHAPE] = value

    def tree_flatten(self):
        """
        Flattens the JaxGraph for JAX PyTree compatibility.

        Both outputs are materialised as tuples: JAX requires the auxiliary data to be
        hashable and comparable, which a live ``dict_keys`` view is not.

        :returns: Flat children and auxiliary data (the keys order).
        """
        children = tuple(self.values())
        aux = tuple(self.keys())
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> JaxGraph:
        """
        Reconstructs a JaxGraph from flattened data, required for JAX compatibility.

        :param aux_data: Sequence of keys matching the order of the children.
        :param children: Sequence of values, in the same order as ``aux_data``.
        :return: A reconstructed JaxGraph instance.
        """
        entries = dict(zip(aux_data, children))
        return cls(
            hyper_edge_sets=entries[HYPER_EDGE_SETS],
            true_shape=entries[TRUE_SHAPE],
            current_shape=entries[CURRENT_SHAPE],
            non_fictitious_addresses=entries[NON_FICTITIOUS_ADDRESSES],
        )

    @property
    def hyper_edge_sets(self) -> dict[str, JaxHyperEdgeSet]:
        """
        Gets the dictionary of hyper-edge sets.

        :return: Dict of hyper-edge set class to JaxHyperEdgeSet.
        """
        return self[HYPER_EDGE_SETS]

    @hyper_edge_sets.setter
    def hyper_edge_sets(self, hyper_edge_set_dict: dict[str, JaxHyperEdgeSet]) -> None:
        """
        Sets the dictionary of hyper-edge sets.

        :param hyper_edge_set_dict: New dictionary of hyper-edge set instances.
        """
        self[HYPER_EDGE_SETS] = hyper_edge_set_dict

    @property
    def non_fictitious_addresses(self) -> jax.Array:
        """
        Gets the mask filled with ones for real addresses, and zeros otherwise.

        :return: Array filled with ones and zeros.
        """
        return self[NON_FICTITIOUS_ADDRESSES]

    @non_fictitious_addresses.setter
    def non_fictitious_addresses(self, value: jax.Array) -> None:
        """
        Sets the address mask.

        :param value: Array filled with ones and zeros.
        """
        self[NON_FICTITIOUS_ADDRESSES] = value

    @property
    def feature_flat_array(self) -> jax.Array:
        """
        Returns an array that concatenates all hyper-edge set features.

        Hyper-edge sets are visited in sorted key order and concatenated along the last
        axis; sets whose own flat array is ``None`` are skipped.

        :return: Jax array of concatenated features.
        """
        flat_arrays = []
        for key in sorted(self.hyper_edge_sets):
            flat = self.hyper_edge_sets[key].feature_flat_array
            if flat is not None:
                flat_arrays.append(flat)
        return jnp.concatenate(flat_arrays, axis=-1)

    @feature_flat_array.setter
    def feature_flat_array(self, value: jax.Array) -> None:
        """
        Updates the flat array contained in the H2MG.

        The input is sliced along its last axis, in sorted key order, and each slice is
        written back to the matching hyper-edge set.

        :param value: Flat feature array.
        :raises ValueError: If the graph holds no hyper-edge set, or if shapes do not match
            the current feature flat array.
        """
        # This guard must come first: reading ``self.feature_flat_array`` below already
        # iterates over the hyper-edge sets and would fail with an unrelated error.
        if self.hyper_edge_sets is None:
            raise ValueError("This jax graph does not contain any hyper-edge set, and can't be cast as a flat array.")
        # Shapes are plain Python tuples, so a direct comparison is enough.
        if self.feature_flat_array.shape != value.shape:
            raise ValueError("Invalid array shape.")

        start = 0
        for key in sorted(self.hyper_edge_sets):
            hyper_edge_set = self.hyper_edge_sets[key]
            if hyper_edge_set.feature_names is None:
                continue
            length = jnp.shape(hyper_edge_set.feature_flat_array)[-1]
            if length > 0:
                # Slice over the last axis.
                hyper_edge_set.feature_flat_array = value[..., start : start + length]
                start += length

    @classmethod
    def from_numpy_graph(cls, graph: Graph, device: Device | None = None, dtype: str = "float32") -> JaxGraph:
        """
        Convert a classical numpy graph to a jax.numpy format for GNN processing.

        This method transforms all array-like attributes of a ``Graph`` object into their
        JAX equivalents, allowing efficient use with JAX transformations and accelerators.

        :param graph: A graph object containing NumPy arrays to convert.
        :param device: Optional JAX device (e.g., CPU, GPU) to place the converted arrays on.
                       If None, JAX uses the default device.
        :param dtype: Desired floating-point precision for converted arrays (e.g., "float32", "float64").
        :return: A JAX-compatible version of the graph, ready for use in GNN pipelines.
        """
        hyper_edge_sets = {
            name: JaxHyperEdgeSet.from_numpy_hyper_edge_set(hyper_edge_set, device=device, dtype=dtype)
            for name, hyper_edge_set in graph.hyper_edge_sets.items()
        }
        return cls(
            hyper_edge_sets=hyper_edge_sets,
            non_fictitious_addresses=np_to_jnp(graph.non_fictitious_addresses, device=device, dtype=dtype),
            true_shape=JaxGraphShape.from_numpy_shape(graph.true_shape, device=device, dtype=dtype),
            current_shape=JaxGraphShape.from_numpy_shape(graph.current_shape, device=device, dtype=dtype),
        )

    def to_numpy_graph(self) -> Graph:
        """
        Convert a jax.numpy graph for GNN processing to a classical numpy graph.

        This method transforms the internal JAX arrays of the graph back into standard
        NumPy arrays, enabling compatibility with non-JAX components.

        :return: A classical ``Graph`` object with NumPy arrays.
        """
        hyper_edge_sets = {
            name: hyper_edge_set.to_numpy_hyper_edge_set() for name, hyper_edge_set in self.hyper_edge_sets.items()
        }
        return Graph(
            hyper_edge_sets=hyper_edge_sets,
            non_fictitious_addresses=jnp_to_np(self.non_fictitious_addresses),
            true_shape=self.true_shape.to_numpy_shape(),
            current_shape=self.current_shape.to_numpy_shape(),
        )

    def quantiles(self, q_list: list[float] | None = None) -> dict[str, jax.Array]:
        """
        Computes quantiles of hyper-edge set features.

        For a single graph the percentile is taken over the whole feature array; for a
        batched graph it is taken along axis 1. Empty feature arrays are skipped.

        :param q_list: Percentiles to compute. Defaults to 0, 10, 25, 50, 75, 90 and 100.
        :return: Mapping "hyper_edge_set/feature/percentile" to values.
        :raises ValueError: If the jax graph is not single or batched and cannot be quantiled.
        """
        if q_list is None:
            q_list = [0.0, 10.0, 25.0, 50.0, 75.0, 90.0, 100.0]

        # These properties walk every hyper-edge set; evaluate them once, not per value.
        single = self.is_single
        batch = self.is_batch

        info = {}
        for object_name, hyper_edge_set in self.hyper_edge_sets.items():
            feature_dict = hyper_edge_set.feature_dict
            if feature_dict is None:
                continue
            for feature_name, array in feature_dict.items():
                if jnp.size(array) == 0:
                    continue
                for q in q_list:
                    if single:
                        value = jnp.nanpercentile(array, q=q)
                    elif batch:
                        value = jnp.nanpercentile(array, q=q, axis=1)
                    else:
                        raise ValueError("This graph is not single or batch and cannot be quantiled.")
                    info[f"{object_name}/{feature_name}/{q}th-percentile"] = value
        return info

    def __str__(self) -> str:
        """Returns each hyper-edge set name followed by its string form, in sorted key order."""
        return "".join("{}\n{}\n".format(key, self.hyper_edge_sets[key]) for key in sorted(self.hyper_edge_sets))

    def to_pickle(self, file_path: str) -> None:
        """
        Saves a jax graph as a pickle file.

        :param file_path: Destination path
        """
        with open(file_path, "wb") as handle:
            pkl.dump(self, handle, protocol=pkl.HIGHEST_PROTOCOL)

    @classmethod
    def from_pickle(cls, *, file_path: str) -> JaxGraph:
        """
        Loads a jax graph from a pickle file.

        :param file_path: Source path.
        :return: Deserialized Graph.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files from trusted sources.
        with open(file_path, "rb") as handle:
            graph = pkl.load(handle)
        return graph

    @property
    def is_batch(self) -> bool:
        """
        Determines if the jax graph is batched.

        :return: True if all hyper-edge sets are batched and if the non-fictitious mask
            is a 2-D array when defined.
        """
        if not all(hyper_edge_set.is_batch for hyper_edge_set in self.hyper_edge_sets.values()):
            return False
        mask = self.non_fictitious_addresses
        return mask is None or len(mask.shape) == 2

    @property
    def is_single(self) -> bool:
        """
        Determines if the graph is single.

        :return: True if all hyper-edge sets are single and if the non-fictitious mask
            is a 1-D array when defined.
        """
        if not all(hyper_edge_set.is_single for hyper_edge_set in self.hyper_edge_sets.values()):
            return False
        mask = self.non_fictitious_addresses
        return mask is None or len(mask.shape) == 1

    def pad(self, target_shape: JaxGraphShape) -> None:
        """
        Pads hyper-edge sets and address mask to match target_shape.

        The address mask is extended with zeros, so the added addresses are fictitious.
        The graph is modified in place and its current shape becomes ``target_shape``.

        :param target_shape: Desired JaxGraphShape with larger dimensions.
        :raises ValueError: If the jax graph is not single.
        """
        if not self.is_single:
            raise ValueError("This jax graph is not single and cannot be padded.")
        for key, hyper_edge_set_shape in target_shape.hyper_edge_sets.items():
            self.hyper_edge_sets[key].pad(hyper_edge_set_shape)
        n_extra_addresses = int(target_shape.addresses) - int(self.current_shape.addresses)
        self.non_fictitious_addresses = jnp.pad(self.non_fictitious_addresses, [0, n_extra_addresses])
        self.current_shape = target_shape

    def unpad(self) -> None:
        """
        Removes padding to restore true_shape.

        The graph is modified in place and its current shape becomes its true shape.

        :raises ValueError: If the jax graph is not single.
        """
        if not self.is_single:
            raise ValueError("This jax graph is not single and cannot be unpadded.")
        for key, hyper_edge_set_shape in self.true_shape.hyper_edge_sets.items():
            self.hyper_edge_sets[key].unpad(hyper_edge_set_shape)
        self.non_fictitious_addresses = self.non_fictitious_addresses[: int(self.true_shape.addresses)]
        self.current_shape = self.true_shape

    def count_connected_components(self) -> tuple[int, jax.Array]:
        """
        Counts connected components, and the component id of each address.

        Every address starts with its own index as label; the largest label is then
        propagated through the hyper-edges until nothing changes. Addresses that end up
        with the same label belong to the same component.

        :return: `(num_components, component_labels)`
        :raises ValueError: If the graph is not single.
        """

        def _max_propagate(*, graph: JaxGraph, h_: jax.Array) -> jax.Array:
            """Propagates the max value of addresses through hyper-edges."""
            labels = h_
            for edge in graph.hyper_edge_sets.values():
                # NOTE(review): a set without ports connects no address; skipping it avoids
                # stacking an empty list. Assumes ``port_dict`` is None or a mapping.
                if not edge.port_dict:
                    continue
                port_indices = [address_array.astype(int) for address_array in edge.port_dict.values()]
                # Largest label seen by each hyper-edge across all of its ports.
                edge_max = jnp.max(jnp.stack([labels[idx] for idx in port_indices], axis=0), axis=0)
                for idx in port_indices:
                    new_val = jnp.maximum(edge_max, labels[idx])
                    labels = labels.at[idx].max(new_val)
            return labels

        if not self.is_single:
            raise ValueError("JaxGraph is not single.")

        h = jnp.arange(len(self.non_fictitious_addresses))
        while True:
            h_new = _max_propagate(graph=self, h_=h)
            converged = bool(jnp.all(h_new == h))
            h = h_new
            if converged:
                break

        u, indices = jnp.unique(h, return_inverse=True)
        return len(u), indices

    def offset_addresses(self, offset: jax.Array | int) -> None:
        """
        Adds an offset on all addresses.

        Should only be used before graph concatenation.

        :param offset: Integer or array to add to addresses
        """
        for hyper_edge_set in self.hyper_edge_sets.values():
            hyper_edge_set.offset_addresses(offset=offset)
def collate_graphs_jax(graph_list: list[JaxGraph]) -> JaxGraph:
    """
    Collate a list of JaxGraphs into a single JaxGraph with padded shapes.

    All input jax graphs must share the same `current_shape`.

    :param graph_list: List of JaxGraph instances to collate.
    :returns: A new JaxGraph whose

        - `true_shape` is the batch of all `true_shape`'s.
        - `current_shape` is the batch of all `current_shape`'s (they must be identical).
        - `hyper_edge_sets` are collated per hyper-edge set class.
        - `non_fictitious_addresses` stacked along a new batch dimension.
    :raises ValueError: If `graph_list` is an empty list.
    :raises AssertionError: If the `current_shape` differs among inputs.
    """
    if not graph_list:
        raise ValueError("collate_graphs requires at least one JaxGraph.")

    first_graph = graph_list[0]

    # All current shapes must be equal. Raised explicitly (rather than with ``assert``)
    # so that the check is not stripped when Python runs with ``-O``.
    current_shape_list = [graph.current_shape for graph in graph_list]
    reference_shape = first_graph.current_shape
    for shape in current_shape_list:
        if not (shape == reference_shape):
            raise AssertionError("All graphs must share the same current_shape to be collated.")
    current_shape_batch = collate_shapes_jax(current_shape_list)

    true_shape_batch = collate_shapes_jax([graph.true_shape for graph in graph_list])

    # The keys of the first graph define which hyper-edge set classes are collated.
    hyper_edge_sets_batch = {
        key: collate_hyper_edge_sets_jax([graph.hyper_edge_sets[key] for graph in graph_list])
        for key in first_graph.hyper_edge_sets
    }

    if first_graph.non_fictitious_addresses is not None:
        non_fictitious_addresses_batch = jnp.stack([graph.non_fictitious_addresses for graph in graph_list], axis=0)
    else:
        non_fictitious_addresses_batch = None

    return JaxGraph(
        hyper_edge_sets=hyper_edge_sets_batch,
        non_fictitious_addresses=non_fictitious_addresses_batch,
        true_shape=true_shape_batch,
        current_shape=current_shape_batch,
    )


def separate_graphs_jax(graph_batch: JaxGraph) -> list[JaxGraph]:
    """
    Split a batch of collated JaxGraph into a list of single JaxGraphs.

    It reverses the operation of :py:func:`collate_graphs_jax`.

    :param graph_batch: A JaxGraph whose `current_shape` and `true_shape` are batched.
    :returns: List of JaxGraphs, each corresponding to one element in the batch.
    """
    current_shape_list = separate_shapes_jax(graph_batch.current_shape)
    true_shape_list = separate_shapes_jax(graph_batch.true_shape)
    n_batch = len(current_shape_list)

    # For each hyper-edge set class, the list of its per-sample hyper-edge sets.
    hyper_edge_set_list_dict = {
        key: separate_hyper_edge_sets_jax(hyper_edge_set) for key, hyper_edge_set in graph_batch.hyper_edge_sets.items()
    }

    if graph_batch.non_fictitious_addresses is not None:
        non_fictitious_addresses_list = jnp.unstack(graph_batch.non_fictitious_addresses, axis=0)
    else:
        non_fictitious_addresses_list = [None] * n_batch

    # Transpose "class -> list over batch" into "list over batch -> class".
    hyper_edge_set_dict_list = [
        {key: per_sample[i] for key, per_sample in hyper_edge_set_list_dict.items()} for i in range(n_batch)
    ]

    return [
        JaxGraph(hyper_edge_sets=e, non_fictitious_addresses=n, true_shape=t, current_shape=c)
        for e, n, t, c in zip(hyper_edge_set_dict_list, non_fictitious_addresses_list, true_shape_list, current_shape_list)
    ]


def concatenate_graphs_jax(graph_list: list[JaxGraph]) -> JaxGraph:
    """
    Concatenates multiple JaxGraphs into a single JaxGraph.

    This function merges a sequence of jax graphs by combining their non-fictitious addresses,
    hyper-edge sets, and shapes into one unified JaxGraph instance. Address offsets are temporarily
    applied to avoid collisions between vertex indices during hyper-edge set concatenation, then
    reverted to preserve the integrity of the original JaxGraph objects.

    :param graph_list: A list of JaxGraph instances to be concatenated.
    :return: A new JaxGraph object representing the concatenation of all input graphs.
    :raises ValueError: If `graph_list` is empty.

    :note: The input graphs are temporarily modified to apply address offsets but are restored to
           their original state before the function returns, including when concatenation fails.
    """
    if not graph_list:
        raise ValueError("graph_list must contain at least one JaxGraph")

    # Offset of a graph = total number of addresses of the graphs that precede it.
    offset_list = []
    running_total = 0
    for graph in graph_list:
        offset_list.append(running_total)
        running_total += len(graph.non_fictitious_addresses)

    non_fictitious_addresses = jnp.concatenate([graph.non_fictitious_addresses for graph in graph_list], axis=0)
    true_shape = sum_shapes_jax([graph.true_shape for graph in graph_list])
    current_shape = sum_shapes_jax([graph.current_shape for graph in graph_list])

    # Only the graphs that were actually shifted are recorded, so that the ``finally``
    # clause undoes exactly what was done even if an error occurs part-way through.
    shifted = []
    try:
        for graph, offset in zip(graph_list, offset_list):
            graph.offset_addresses(offset=offset)
            shifted.append((graph, offset))
        hyper_edge_sets = {
            key: concatenate_hyper_edge_sets_jax([graph.hyper_edge_sets[key] for graph in graph_list])
            for key in graph_list[0].hyper_edge_sets
        }
    finally:
        for graph, offset in shifted:
            graph.offset_addresses(offset=-offset)

    return JaxGraph(
        hyper_edge_sets=hyper_edge_sets,
        non_fictitious_addresses=non_fictitious_addresses,
        true_shape=true_shape,
        current_shape=current_shape,
    )


def check_hyper_edge_set_dict_type_jax(hyper_edge_set_dict: dict[str, JaxHyperEdgeSet]) -> None:
    """
    Validate that the provided mapping is a dictionary of JaxHyperEdgeSet instances.

    :param hyper_edge_set_dict: A mapping from string keys to JaxHyperEdgeSet objects.
    :raises TypeError: If `hyper_edge_set_dict` is not a dictionary, or if any value in it is not a JaxHyperEdgeSet.
    """
    if not isinstance(hyper_edge_set_dict, dict):
        raise TypeError("Provided 'hyper_edge_set_dict' is not a 'dict', but a {}.".format(type(hyper_edge_set_dict)))
    for key, hyper_edge_set in hyper_edge_set_dict.items():
        if not isinstance(hyper_edge_set, JaxHyperEdgeSet):
            raise TypeError("Item associated with '{}' key is not an 'JaxHyperEdgeSet'.".format(key))


def check_valid_addresses_jax(hyper_edge_set_dict: dict[str, JaxHyperEdgeSet], n_addresses: jax.Array) -> None:
    """
    Ensure that all address indices in each JaxHyperEdgeSet are valid with respect to the registry.

    Iterates over all hyper-edge sets in `hyper_edge_set_dict` and, if a hyper-edge set defines
    `port_names`, checks that every entry of its `port_array` is strictly smaller than `n_addresses`.
    Only the upper bound is checked.

    :param hyper_edge_set_dict: Mapping from hyper-edge set names to JaxHyperEdgeSet objects containing address arrays.
    :param n_addresses: Number of addresses in the registry; the same bound is applied to every hyper-edge set.
    :raises AssertionError: If any address in any hyper-edge set is outside the valid range
                            (i.e., not less than `n_addresses`).
    """
    for key, hyper_edge_set in hyper_edge_set_dict.items():
        if hyper_edge_set.port_names is None:
            continue
        # Raised explicitly (rather than with ``assert``) so the check survives ``python -O``.
        if not jnp.all(hyper_edge_set.port_array < n_addresses):
            raise AssertionError("Hyper-edge set '{}' contains addresses that are not smaller than n_addresses.".format(key))


def get_statistics_jax(graph: JaxGraph, axis: int | None = None, norm_graph: JaxGraph | None = None) -> dict:
    """
    Extract summary statistics from each feature array in the jax graph's hyper-edge sets.

    For every feature of every hyper-edge in `graph`, computes:

    - Root Mean Squared Error (RMSE)
    - Mean Absolute Error (MAE)
    - First and second moments (mean, standard deviation)
    - Range and quantiles (min, 10th, 25th, 50th, 75th, 90th, max)

    If `norm_graph` is provided, then it also returns normalized metrics:

    - Normalized RMSE (nrmse)
    - Normalized MAE (nmae)

    The normalization reference is centred on its overall mean before its own RMSE / MAE
    is taken, and a small constant (1e-9) is added to the denominator.

    .. warning::
        ``graph`` is modified in place: the features of fictitious objects are overwritten
        with NaN so that the NaN-aware reductions ignore them.

    :param graph: JaxGraph object containing hyper-edge sets with feature dictionaries.
    :param axis: Axis along which to compute statistics. If None, statistics are computed over the flattened array.
    :param norm_graph: Optional JaxGraph whose features serve as normalization reference.
    :return: A dictionary mapping keys of the form ``"{hyper_edge_set_name}/{feature_name}/{stat}"``
             to their computed values, as JAX arrays.
    """
    # Convert fictitious features to NaN.
    for hyper_edge_set in graph.hyper_edge_sets.values():
        mask = hyper_edge_set.non_fictitious
        if hyper_edge_set.feature_array is not None:
            hyper_edge_set.feature_array = hyper_edge_set.feature_array.at[mask == 0].set(jnp.nan)

    info = {}
    for object_name, hyper_edge_set in graph.hyper_edge_sets.items():
        if hyper_edge_set.feature_dict is None:
            continue
        for feature_name, array in hyper_edge_set.feature_dict.items():
            # Empty features are replaced by a single zero so that every reduction is defined.
            if array.size == 0:
                array = jnp.array([[0.0]]) if axis == 1 else jnp.array([0.0])

            prefix = "{}/{}".format(object_name, feature_name)

            # Centred normalization reference, computed once for both nrmse and nmae.
            centred_norm = None
            if norm_graph is not None:
                norm_array = norm_graph.hyper_edge_sets[object_name].feature_dict[feature_name]
                centred_norm = norm_array - jnp.nanmean(norm_array)

            # Root Mean Squared Error
            rmse = jnp.sqrt(jnp.nanmean(array**2, axis=axis))
            info[prefix + "/rmse"] = rmse
            if centred_norm is not None:
                info[prefix + "/nrmse"] = rmse / (jnp.sqrt(jnp.nanmean(centred_norm**2, axis=axis)) + 1e-9)

            # Mean Absolute Error
            mae = jnp.nanmean(jnp.abs(array), axis=axis)
            info[prefix + "/mae"] = mae
            if centred_norm is not None:
                info[prefix + "/nmae"] = mae / (jnp.nanmean(jnp.abs(centred_norm), axis=axis) + 1e-9)

            # Moments
            info[prefix + "/mean"] = jnp.nanmean(array, axis=axis)
            info[prefix + "/std"] = jnp.nanstd(array, axis=axis)

            # Range and quantiles, from largest to smallest.
            info[prefix + "/max"] = jnp.nanmax(array, axis=axis)
            for q in (90, 75, 50, 25, 10):
                info["{}/{}th".format(prefix, q)] = jnp.nanpercentile(array, q=q, axis=axis)
            info[prefix + "/min"] = jnp.nanmin(array, axis=axis)
    return info