# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from abc import ABC, abstractmethod
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import initializers
from flax.typing import Initializer
from energnn.graph import GraphStructure, JaxGraph
from energnn.model.utils import Activation, MLP, gather, scatter_add
[docs]
class MessagePassingFunction(nnx.Module, ABC):
r"""Interface for a message function :math:`\xi_\theta` in a GNN message passing scheme."""
@abstractmethod
def __call__(self, graph: JaxGraph, coordinates: jax.Array, get_info: bool = False) -> tuple[jax.Array, dict]:
"""Should take as input a tuple (graph, coordinates) and return new coordinates."""
raise NotImplementedError
[docs]
class LocalSumMessagePassingFunction(MessagePassingFunction):
r"""
Local sum-based message function module for GNN message passing.
This module aggregates messages from each node's local neighborhood by applying
a class- and port-specific MLP :math:`\xi^{c,o}_\theta` to hyper-edge features and neighbor coordinates,
summing the results across all incoming ports, and applying a final activation :math:`\sigma`.
For each address :math:`a`, the output is defined as:
.. math::
\psi_\theta(h,x)_a = \sigma \left( \sum_{(c,e,o)\in \mathcal{N}_x(a)} \xi^{c,o}_\theta(h_e, x_e)\right),
where :math:`\xi^{c,o}_\theta` is a class-specific and port-specific MLP, :math:`\sigma` is an
element-wise activation function, and :math:`h_e := (h_{o(e)})_{o \in {\mathcal{O}^c}}` is the concatenation of
port coordinates of hyper-edge :math:`e`.
:param in_graph_structure: Input graph structure.
:param in_array_size: Size of the input coordinate arrays.
:param hidden_sizes: Hidden sizes of the MLPs :math:`\xi^{c,o}_\theta`.
:param activation: Activation function for the MLPs :math:`\xi^{c,o}_\theta`.
:param out_size: Output size of the MLPs :math:`\xi^{c,o}_\theta`.
:param use_bias: Whether to use bias in the MLPs :math:`\xi^{c,o}_\theta`.
:param kernel_init: Kernel initializer for the MLPs :math:`\xi^{c,o}_\theta`.
:param bias_init: Bias initializer for the MLPs :math:`\xi^{c,o}_\theta`.
:param final_activation: Final activation function for the MLPs :math:`\xi^{c,o}_\theta`.
:param outer_activation: Activation function :math:`\sigma` applied over the output.
:param encoded_feature_size: None if the input data has not been encoded, otherwise the size of the encoded features.
:param port_scatter_blacklist: Dictionary mapping hyper-edge set keys to lists of port keys to be excluded from the sum.
:param seed: Seed for RNG streams for weight initialization.
"""
def __init__(
self,
in_graph_structure: GraphStructure,
in_array_size: int,
hidden_sizes: list[int],
activation: Activation = nnx.relu,
out_size: int = 1,
use_bias: bool = True,
kernel_init: Initializer = initializers.lecun_normal(),
bias_init: Initializer = initializers.zeros_init(),
final_activation: Activation | None = None,
outer_activation: Activation = nnx.tanh,
encoded_feature_size: int | None = None,
port_scatter_blacklist: dict[str, list[str]] | None = None,
seed: int | None = None,
rngs: nnx.Rngs | None = None,
):
self.in_graph_structure = in_graph_structure
self.in_array_size = in_array_size
self.hidden_sizes = hidden_sizes
self.activation = activation
self.out_size = out_size
self.use_bias = use_bias
self.kernel_init = kernel_init
self.bias_init = bias_init
self.final_activation = final_activation
self.outer_activation = outer_activation
self.encoded_feature_size = encoded_feature_size
if port_scatter_blacklist is None:
self.port_scatter_blacklist = {}
else:
self.port_scatter_blacklist = port_scatter_blacklist
self.mlp_tree = self._build_mlp_tree(seed=seed, rngs=rngs)
def _build_mlp_tree(self, seed: int = 0, rngs: nnx.Rngs | None = None) -> dict[str, dict[str, MLP]]:
if rngs is None:
rngs = nnx.Rngs(seed)
elif seed is not None:
raise ValueError("Seed must be None when rngs are provided.")
mlp_tree = {}
for key, hyper_edge_set_structure in self.in_graph_structure.hyper_edge_sets.items():
if hyper_edge_set_structure.port_list is not None and len(hyper_edge_set_structure.port_list) > 0:
n_ports = len(hyper_edge_set_structure.port_list)
in_size = self.in_array_size * n_ports
if hyper_edge_set_structure.feature_list is not None and len(hyper_edge_set_structure.feature_list) > 0:
if self.encoded_feature_size is not None:
in_size += self.encoded_feature_size
else:
in_size += len(hyper_edge_set_structure.feature_list)
if key not in mlp_tree.keys():
mlp_tree[key] = {}
for port_key in hyper_edge_set_structure.port_list:
if port_key not in self.port_scatter_blacklist.get(key, []):
mlp_tree[key][port_key] = MLP(
in_size=in_size,
hidden_sizes=self.hidden_sizes,
activation=self.activation,
out_size=self.out_size,
use_bias=self.use_bias,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
final_activation=self.final_activation,
rngs=rngs,
)
return nnx.data(mlp_tree)
def __call__(self, *, graph: JaxGraph, coordinates: jax.Array, get_info: bool = False) -> tuple[jax.Array, dict]:
def sum_over_edges(_accumulator, edge_mlp_tuple):
"""Sums the output of class and port specific MLPs through ports of all hyper-edge sets in the graph."""
hyper_edge_set, mlp_dict = edge_mlp_tuple
input_array = []
if hyper_edge_set.feature_names is not None:
input_array.append(hyper_edge_set.feature_array)
for port_name, port_array in hyper_edge_set.port_dict.items():
input_array.append(gather(coordinates=coordinates, addresses=port_array))
input_array = jnp.concatenate(input_array, axis=-1)
non_fictitious_mask = jnp.expand_dims(hyper_edge_set.non_fictitious, -1)
def sum_over_ports(__accumulator: jax.Array, mlp_port: tuple[MLP, jax.Array]) -> jax.Array:
"""Sums the outputs of port-specific MLPs through ports of a given hyper-edge set."""
mlp, _port_array = mlp_port
increment = mlp(input_array * non_fictitious_mask) * non_fictitious_mask
return scatter_add(accumulator=__accumulator, increment=increment, addresses=_port_array)
mlp_port_dict = {port_name: (mlp, hyper_edge_set.port_dict[port_name]) for port_name, mlp in mlp_dict.items()}
return jax.tree.reduce(
sum_over_ports, mlp_port_dict, initializer=_accumulator, is_leaf=lambda x: isinstance(x, tuple)
)
initializer = jnp.zeros((coordinates.shape[0], self.out_size))
edge_mlp_dict = {key: (hyper_edge_set, self.mlp_tree[key]) for key, hyper_edge_set in graph.hyper_edge_sets.items()}
accumulator = jax.tree.reduce(
sum_over_edges,
edge_mlp_dict,
initializer=initializer,
is_leaf=lambda x: isinstance(x, tuple),
)
return self.outer_activation(accumulator), {}
[docs]
class IdentityMessagePassingFunction(MessagePassingFunction):
r"""
Identity local message function module for GNN message passing.
This module returns the node features unchanged as the local message.
It implements the identity mapping on node features:
.. math::
h^\rightarrow_a = h_a
"""
def __init__(self):
pass
def __call__(self, *, graph: JaxGraph, coordinates: jax.Array, get_info: bool = False) -> tuple[jax.Array, dict]:
return coordinates, {}