Source code for energnn.model.coupler.message_passing.recurrent_coupler
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
import logging
import jax
import jax.numpy as jnp
from flax import nnx
from energnn.graph import JaxGraph
from energnn.model.utils import MLP
from .message_passing_function import MessagePassingFunction
from ..coupler import Coupler
logger = logging.getLogger(__name__)
[docs]
class RecurrentCoupler(Coupler):
r"""
Simplified version of the Neural Ordinary Differential Equation solver.
The following recurrent system is used.:
.. math::
\forall a \in \mathcal{A}_x, h_a(t+\delta t) = h_a(t+\delta t) +
\delta t \times \phi_\theta(\psi^1_\theta(h;x)_a, \dots, \psi^n_\theta(h;x)_a),
with the following initial condition:
.. math::
\forall a \in \mathcal{A}_x, h_a(t=0) = [0, \dots, 0].
:param phi: Outer MLP :math:`\phi_\theta`.
:param message_functions: List of message functions :math:`(\psi^i_\theta)_i`.
:param n_steps: Number of message passing steps.
"""
def __init__(
self,
phi: MLP,
message_functions: list[MessagePassingFunction],
n_steps: int,
):
super().__init__()
self.phi = phi
self.message_functions = nnx.List(message_functions)
self.n_steps = n_steps
self.dt = 1 / self.n_steps
def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[jax.Array, dict]:
def F(t, coordinates, graph):
"""Residual function."""
messages = []
for m in self.message_functions:
message, info = m(graph=graph, coordinates=coordinates)
messages.append(message)
messages = jnp.concatenate(messages, axis=-1)
return self.phi(messages)
h = jnp.zeros([jnp.shape(graph.non_fictitious_addresses)[0], self.phi.out_size])
dt = 1 / self.n_steps
for _ in range(self.n_steps):
h = h + dt * F(0, h, graph)
return h, {}
@staticmethod
def log_solved():
"""Log a message indicating successful ODE solve."""
logger.info("ODE solved")