Source code for energnn.model.decoder.invariant_decoder
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
import jax
import jax.numpy as jnp
from energnn.graph import JaxGraph
from energnn.model.utils import MLP
from .decoder import Decoder
[docs]
class InvariantDecoder(Decoder, ABC):
    """Interface shared by decoders that emit a single global output per graph.

    Concrete subclasses pool the information carried by every address in a
    permutation-invariant way, so that the result is one global vector rather
    than one value per address.
    """

    @abstractmethod
    def __call__(
        self,
        *,
        graph: JaxGraph,
        coordinates: jax.Array,
        get_info: bool = False,
    ) -> tuple[jax.Array, dict]:
        """Turn latent coordinates into a global decision vector.

        All arguments are keyword-only.

        :param graph: Graph being decoded.
        :param coordinates: Latent coordinates, given as a JAX array.
        :param get_info: When True, extra tracking information is expected in the returned dictionary.
        :return: A pair ``(decision vector, info dictionary)``.
        :raises NotImplementedError: When a subclass delegates here without providing its own implementation.
        """
        raise NotImplementedError
[docs]
class SumInvariantDecoder(InvariantDecoder):
    r"""
    Invariant decoder that pools address information with a sum.

    .. math::
        \hat{y} = \phi_\theta \left( \sum_{a \in \mathcal{A}(x)} \psi_\theta(h_a)\right),

    where the inner map :math:`\psi_\theta` and the outer map :math:`\phi_\theta` are both trainable MLPs.

    :param psi: Inner MLP :math:`\psi_\theta`, applied to the coordinates before pooling.
    :param phi: Outer MLP :math:`\phi_\theta`, applied to the pooled vector.
    """

    def __init__(self, *, psi: MLP, phi: MLP) -> None:
        super().__init__()
        self.psi = psi
        self.phi = phi

    def __call__(
        self,
        *,
        graph: JaxGraph,
        coordinates: jax.Array,
        get_info: bool = False,
    ) -> tuple[jax.Array, dict]:
        """Decode coordinates into a global vector by masked sum pooling.

        :param graph: Graph whose ``non_fictitious_addresses`` mask weights each address.
        :param coordinates: Latent coordinates fed to ``psi``.
        :param get_info: Accepted for interface compatibility; not used by this decoder.
        :return: Tuple of the output of ``phi`` and an empty info dictionary.
        """
        embedded = self.psi(coordinates)
        # Weight every row by the mask, broadcast over the trailing feature axis
        # (presumably zeroing fictitious/padding addresses -- confirm against JaxGraph).
        address_mask = jnp.expand_dims(graph.non_fictitious_addresses, -1)
        pooled = jnp.sum(embedded * address_mask, axis=0)
        return self.phi(pooled), {}
[docs]
class MeanInvariantDecoder(InvariantDecoder):
    r"""
    Invariant decoder that pools address information with a mean.

    .. math::
        \hat{y} = \phi_\theta \left( \frac{1}{\vert \mathcal{A}(x) \vert} \sum_{a \in \mathcal{A}(x)} \psi_\theta(h_a) \right),

    where the inner map :math:`\psi_\theta` and the outer map :math:`\phi_\theta` are both trainable MLPs.

    :param psi: Inner MLP :math:`\psi_\theta`, applied to the coordinates before pooling.
    :param phi: Outer MLP :math:`\phi_\theta`, applied to the pooled vector.
    """

    def __init__(self, *, psi: MLP, phi: MLP) -> None:
        super().__init__()
        self.psi = psi
        self.phi = phi

    def __call__(
        self,
        *,
        graph: JaxGraph,
        coordinates: jax.Array,
        get_info: bool = False,
    ) -> tuple[jax.Array, dict]:
        """Decode coordinates into a global vector by masked mean pooling.

        :param graph: Graph whose ``non_fictitious_addresses`` mask weights each address.
        :param coordinates: Latent coordinates fed to ``psi``.
        :param get_info: Accepted for interface compatibility; not used by this decoder.
        :return: Tuple of the output of ``phi`` and an empty info dictionary.
        """
        embedded = self.psi(coordinates)
        address_mask = graph.non_fictitious_addresses
        # Masked sum over the leading axis; the mask is broadcast over the trailing feature axis.
        masked_total = jnp.sum(embedded * jnp.expand_dims(address_mask, -1), axis=0)
        # The small constant presumably guards against dividing by zero when the mask sums to 0.
        mask_total = jnp.sum(address_mask, axis=0) + 1e-9
        pooled = masked_total / mask_total
        return self.phi(pooled), {}