# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax import Device
from jax.tree_util import register_pytree_node_class
from energnn.graph.jax.hyper_edge_set import JaxHyperEdgeSet
from energnn.graph.jax.utils import jnp_to_np, np_to_jnp
from energnn.graph.shape import GraphShape
HYPER_EDGE_SETS = "hyper_edge_sets"
ADDRESSES = "addresses"


@register_pytree_node_class
class JaxGraphShape(dict):
    """
    PyTree container storing the number of objects in each hyper-edge set class and the number of addresses.

    This class inherits from ``dict`` and stores exactly two keys:

    :param hyper_edge_sets: Dictionary that maps each hyper-edge set class name to its number of objects.
    :param addresses: Number of addresses in the graph.

    The PyTree methods ``tree_flatten`` and ``tree_unflatten`` make this object
    compatible with JAX transformations (jit, vmap, etc.).
    """

    def __init__(self, *, hyper_edge_sets: dict[str, jax.Array], addresses: jax.Array) -> None:
        super().__init__()
        self[HYPER_EDGE_SETS] = hyper_edge_sets
        self[ADDRESSES] = addresses

    @classmethod
    def from_dict(cls, hyper_edge_set_dict: dict[str, JaxHyperEdgeSet], non_fictitious: jax.Array | None) -> JaxGraphShape:
        """
        Builds a new JaxGraphShape object from a hyper-edge set dictionary and an optional address array.

        :param hyper_edge_set_dict: Mapping from a hyper-edge set class name to a `JaxHyperEdgeSet` instance;
            each instance's ``n_obj`` becomes the count for that class.
        :param non_fictitious: Optional array whose first dimension (``shape[0]``) is used as the number
            of addresses. If None, the number of addresses is 0.
        :return: New JaxGraphShape instance.
        """
        hyper_edge_set_shape_dict = {k: jnp.array(v.n_obj) for (k, v) in hyper_edge_set_dict.items()}
        if non_fictitious is not None:
            addresses = jnp.array(non_fictitious.shape[0])
        else:
            # 0-d scalar, like the branch above, so int(), jnp.stack and comparisons
            # behave identically whichever branch produced the shape.
            addresses = jnp.array(0)
        return cls(hyper_edge_sets=hyper_edge_set_shape_dict, addresses=addresses)

    def to_jsonable_dict(self) -> dict:
        """
        Serialize JaxGraphShape to a JSON-friendly dict.

        :return: Dict with 'hyper_edge_sets' mapping class names to ints, and 'addresses' as an int.
        """
        return {HYPER_EDGE_SETS: {k: int(v) for k, v in self.hyper_edge_sets.items()}, ADDRESSES: int(self.addresses)}

    @classmethod
    def from_jsonable_dict(cls, count_shape: dict) -> JaxGraphShape:
        """
        Deserialize JaxGraphShape from a JSON-friendly dictionary.

        :param count_shape: Dict with 'hyper_edge_sets' and 'addresses'.
        :return: Reconstructed JaxGraphShape.
        """
        hyper_edge_sets = {k: jnp.array(v) for k, v in count_shape[HYPER_EDGE_SETS].items()}
        addresses = jnp.array(count_shape[ADDRESSES])
        return cls(hyper_edge_sets=hyper_edge_sets, addresses=addresses)

    @staticmethod
    def _merged_keys(a: JaxGraphShape, b: JaxGraphShape) -> list[str]:
        """Class names of ``a`` in order, followed by those only present in ``b`` (deterministic order)."""
        # dict.fromkeys (not cls.fromkeys) deduplicates while preserving first-seen order.
        return list(dict.fromkeys([*a.hyper_edge_sets, *b.hyper_edge_sets]))

    @classmethod
    def max(cls, a: JaxGraphShape, b: JaxGraphShape) -> JaxGraphShape:
        """
        Returns the maximum shape of 2 graph shapes.

        Classes present in only one of the two shapes are kept with their count unchanged.

        :param a: A first graph shape.
        :param b: A second graph shape.
        :return: A graph shape with maxima per hyper-edge set class and addresses.
        """
        a_sets, b_sets = a.hyper_edge_sets, b.hyper_edge_sets
        hyper_edge_set_shape_max = {}
        for name in cls._merged_keys(a, b):
            if name in a_sets and name in b_sets:
                hyper_edge_set_shape_max[name] = jnp.maximum(a_sets[name], b_sets[name])
            else:
                # Take the present value directly: comparing it against -inf would
                # promote integer counts to a floating-point dtype.
                present = a_sets[name] if name in a_sets else b_sets[name]
                hyper_edge_set_shape_max[name] = jnp.asarray(present)
        addresses = jnp.maximum(a.addresses, b.addresses)
        return cls(hyper_edge_sets=hyper_edge_set_shape_max, addresses=addresses)

    @classmethod
    def sum(cls, a: JaxGraphShape, b: JaxGraphShape) -> JaxGraphShape:
        """
        Returns the sum shape of 2 graph shapes.

        A class missing from one shape counts as 0 on that side.

        :param a: A first graph shape.
        :param b: A second graph shape.
        :return: A graph shape with summed counts per hyper-edge set class and addresses.
        """
        a_sets, b_sets = a.hyper_edge_sets, b.hyper_edge_sets
        hyper_edge_set_shape_sum = {
            name: a_sets.get(name, 0) + b_sets.get(name, 0) for name in cls._merged_keys(a, b)
        }
        addresses = a.addresses + b.addresses
        return cls(hyper_edge_sets=hyper_edge_set_shape_sum, addresses=addresses)

    def tree_flatten(self):
        """
        Flatten the JaxGraphShape for JAX PyTree compatibility.

        :returns: A tuple of children (hyper-edge set dict, addresses) and a tuple of the matching keys
            used as auxiliary data.
        """
        # Tuples rather than live dict views: aux_data must be hashable and compare
        # order-sensitively, and neither tuple may change if this dict is mutated later.
        children = tuple(self.values())
        aux = tuple(self.keys())
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> JaxGraphShape:
        """
        Reconstruct a JaxGraphShape from flattened data, required for JAX compatibility.

        :param aux_data: Sequence of keys matching the order of the children.
        :param children: Sequence of values.
        :return: A reconstructed JaxGraphShape instance.
        """
        d = dict(zip(aux_data, children))
        return cls(hyper_edge_sets=d[HYPER_EDGE_SETS], addresses=d[ADDRESSES])

    @property
    def hyper_edge_sets(self) -> dict[str, jax.Array]:
        """Dictionary of hyper-edge set counts, keyed by class name."""
        return self[HYPER_EDGE_SETS]

    @property
    def addresses(self) -> jax.Array:
        """Number of addresses in the graph."""
        return self[ADDRESSES]

    @classmethod
    def from_numpy_shape(cls, shape: GraphShape, device: Device | None = None, dtype: str = "float32") -> JaxGraphShape:
        """
        Convert a classical numpy shape to a jax.numpy format for GNN processing.

        This method transforms all array-like attributes of a ``GraphShape`` object into
        their JAX equivalents, allowing efficient use with JAX transformations and accelerators.

        :param shape: A shape object containing NumPy arrays to convert.
        :param device: Optional JAX device (e.g., CPU, GPU) to place the converted arrays on.
            If None, JAX uses the default device.
        :param dtype: Desired floating-point precision for converted arrays (e.g., "float32", "float64").
        :return: A JAX-compatible version of the shape, ready for use in GNN pipelines.
        """
        hyper_edge_sets = np_to_jnp(shape.hyper_edge_sets, device=device, dtype=dtype)
        addresses = np_to_jnp(shape.addresses, device=device, dtype=dtype)
        return cls(hyper_edge_sets=hyper_edge_sets, addresses=addresses)

    def to_numpy_shape(self) -> GraphShape:
        """
        Convert a jax.numpy shape for GNN processing to a classical numpy shape.

        This method transforms the internal JAX arrays of the shape back into standard
        NumPy arrays, enabling compatibility with non-JAX components.

        :return: A classical ``GraphShape`` object with NumPy arrays.
        """
        hyper_edge_sets = jnp_to_np(self.hyper_edge_sets)
        addresses = jnp_to_np(self.addresses)
        return GraphShape(hyper_edge_sets=hyper_edge_sets, addresses=addresses)

    @property
    def array(self) -> jax.Array:
        """
        Hyper-edge set counts stacked along a new last axis, in dict order.

        Raises ValueError (from ``jnp.stack``) if there are no hyper-edge sets.
        """
        return jnp.stack(list(self.hyper_edge_sets.values()), axis=-1)

    @property
    def is_single(self) -> bool:
        """True if the stacked array is 1-D."""
        return self.array.ndim == 1

    @property
    def is_batch(self) -> bool:
        """True if the stacked array is 2-D."""
        return self.array.ndim == 2

    @property
    def n_batch(self) -> int:
        """
        Return the batch size.

        :raises ValueError: If JaxGraphShape is not batched.
        """
        if not self.is_batch:
            raise ValueError("JaxGraphShape is not batched.")
        return self.array.shape[0]
def collate_shapes_jax(shape_list: list[JaxGraphShape]) -> JaxGraphShape:
    """
    Stack a list of single JaxGraphShape objects into one batched JaxGraphShape.

    The hyper-edge set class names are taken from the first shape. Every other shape
    must provide the same names, otherwise the lookup raises KeyError.

    :param shape_list: List of JaxGraphShape objects (must share hyper-edge set keys).
    :return: Batched JaxGraphShape whose arrays gain a leading batch axis.
    :raises ValueError: If the input list is empty.
    """
    if not shape_list:
        raise ValueError("Empty shape list provided to collate_shapes_jax.")

    stacked_sets = {}
    for class_name in shape_list[0].hyper_edge_sets:
        # One column per class: that class's count in every graph, in list order.
        column = [graph_shape.hyper_edge_sets[class_name] for graph_shape in shape_list]
        stacked_sets[class_name] = jnp.stack(column, axis=0)

    stacked_addresses = jnp.stack([graph_shape.addresses for graph_shape in shape_list], axis=0)
    return JaxGraphShape(hyper_edge_sets=stacked_sets, addresses=stacked_addresses)
def separate_shapes_jax(shape_batch: JaxGraphShape) -> list[JaxGraphShape]:
    """
    Split a batched JaxGraphShape into one JaxGraphShape per batch entry.

    Each hyper-edge set array and the addresses array are unstacked along axis 0, and the
    i-th slices are regrouped into the i-th output shape.

    :param shape_batch: JaxGraphShape with 2D hyper-edge sets and address arrays.
    :return: List of JaxGraphShape (one per batch entry).
    :raises ValueError: If input is not batched.
    """
    if not shape_batch.is_batch:
        raise ValueError("Input JaxGraphShape must be batched for separation.")

    per_graph_addresses = jnp.unstack(shape_batch.addresses, axis=0)

    class_names = list(shape_batch.hyper_edge_sets)
    unstacked_columns = [jnp.unstack(shape_batch.hyper_edge_sets[name]) for name in class_names]
    # Transpose "one list per class" into "one dict per graph".
    per_graph_sets = [dict(zip(class_names, row)) for row in zip(*unstacked_columns)]

    return [
        JaxGraphShape(hyper_edge_sets=graph_sets, addresses=graph_addresses)
        for graph_addresses, graph_sets in zip(per_graph_addresses, per_graph_sets)
    ]
def max_shape_jax(graph_shape_list: list[JaxGraphShape]) -> JaxGraphShape:
    """
    Returns the maximum jax graph shape from a list of jax graph shapes.

    Hyper-edge set classes that appear in only some of the shapes are still included
    in the output. A single-element list returns that element itself.

    :param graph_shape_list: List of jax graph shapes to be compared.
    :return: JaxGraphShape with maxima per hyper-edge set class and addresses.
    :raises ValueError: If the list is empty or contains non-JaxGraphShape.
    """
    if not graph_shape_list:
        raise ValueError("Empty input list given for max_shape_jax.")
    # Validate everything before combining, so invalid input always yields this
    # ValueError rather than an error raised from inside JaxGraphShape.max.
    if not all(isinstance(graph_shape, JaxGraphShape) for graph_shape in graph_shape_list):
        raise ValueError("Invalid input in graph_list, expected JaxGraphShape.")
    max_graph_shape = graph_shape_list[0]
    # Start from the second element: taking the max of the first shape with itself is redundant.
    for graph_shape in graph_shape_list[1:]:
        max_graph_shape = JaxGraphShape.max(max_graph_shape, graph_shape)
    return max_graph_shape
def sum_shapes_jax(graph_shape_list: list[JaxGraphShape]) -> JaxGraphShape:
    """
    Returns the sum jax graph shape from a list of jax graph shapes.

    A single-element list returns that element itself.

    :param graph_shape_list: List of jax graph shapes to be summed.
    :return: JaxGraphShape with summed counts per hyper-edge set class and addresses.
    :raises ValueError: If the list is empty or contains non-JaxGraphShape.
    """
    if not graph_shape_list:
        raise ValueError("Empty input list given for sum_shapes_jax.")
    # Check every element, including the first. The first one seeds the accumulator,
    # so a per-iteration check over the rest of the list would never validate it.
    if not all(isinstance(graph_shape, JaxGraphShape) for graph_shape in graph_shape_list):
        raise ValueError("Invalid input in graph_list, expected JaxGraphShape.")
    sum_graph_shape = graph_shape_list[0]
    for graph_shape in graph_shape_list[1:]:
        sum_graph_shape = JaxGraphShape.sum(sum_graph_shape, graph_shape)
    return sum_graph_shape