Source code for energnn.model.gnn

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

import jax
from flax import nnx

from energnn.graph import JaxGraph
from .coupler import Coupler
from .decoder import Decoder
from .encoder import Encoder
from .normalizer import Normalizer


[docs] class GNN(nnx.Module): """ Simple Graph Neural Network (GNN) model designed to handle Hyper Heterogeneous Multi Graphs (H2MGs). The model consists of a normalization step, an encoding step, a coupling step, and a decoding step. The decoder can either be invariant or equivariant, depending on the task requirements. :param normalizer: Maps the input features to a learning-compatible range. :type normalizer: Normalizer :param encoder: Embeds hyper-edge set features into a latent space. :type encoder: Encoder :param coupler: Outputs latent coordinates for each address present in the input graph. :type coupler: Coupler :param decoder: Maps latent coordinates and encoded graph to a meaningful output. :type decoder: Decoder """ def __init__(self, normalizer: Normalizer, encoder: Encoder, coupler: Coupler, decoder: Decoder): self.normalizer = normalizer self.encoder = encoder self.coupler = coupler self.decoder = decoder
[docs] def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph | jax.Array, dict]: """ Processes a given graph through a sequence of steps: normalization, encoding, coupling, and decoding. The method applies a series of transformations to the input graph and returns a decoded graph / array along with optional processing information. :param graph: The input graph to be processed. :param get_info: A boolean indicating whether detailed processing information should be returned. Defaults to False. :return: A tuple consisting of the processed decoded graph / array and an optional dictionary with detailed information about each processing step if `get_info` is True. """ info = {} normalized_graph, info["normalization"] = self.normalizer(graph=graph, get_info=get_info) encoded_graph, info["encoding"] = self.encoder(graph=normalized_graph, get_info=get_info) latent_coordinates, info["coupling"] = self.coupler(graph=encoded_graph, get_info=get_info) output, info["decoding"] = self.decoder(coordinates=latent_coordinates, graph=encoded_graph, get_info=get_info) return output, info
[docs] def forward_batch(self, *, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph | jax.Array, dict]: """Applies the model to a batch of graphs. Only the encoder, coupler, and decoder modules are vmapped, while the normalization module is not. :param graph: Batch of input graphs. :param get_info: Whether to return additional information about the processing steps. """ def apply_core(encoder, coupler, decoder, graph, get_info): info = {} encoded_graph, info["encoding"] = encoder(graph=graph, get_info=get_info) latent_coordinates, info["coupling"] = coupler(graph=encoded_graph, get_info=get_info) output, info["decoding"] = decoder(coordinates=latent_coordinates, graph=encoded_graph, get_info=get_info) return output, info normalized_graph, info_norm = self.normalizer(graph=graph, get_info=get_info) output, info_core = jax.vmap(apply_core, in_axes=[None, None, None, 0, None], out_axes=0)( self.encoder, self.coupler, self.decoder, normalized_graph, get_info ) return output, info_norm | info_core