Source code for energnn.model.coupler.message_passing.node_coupler
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
import logging
import diffrax
import jax
import jax.numpy as jnp
from flax import nnx
from energnn.graph import JaxGraph
from energnn.model.utils import MLP
from .message_passing_function import MessagePassingFunction
from ..coupler import Coupler
logger = logging.getLogger(__name__)
[docs]
class NODECoupler(Coupler):
    r"""
    Output coordinates are computed by solving a Neural Ordinary Differential Equation.

    The following ordinary differential equation is integrated between :math:`t=0` and :math:`t=1`:

    .. math::
        \forall a \in \mathcal{A}_x, \frac{dh_a}{dt} = \phi_\theta(\psi^1_\theta(h;x)_a, \dots, \psi^n_\theta(h;x)_a),

    with the following initial condition:

    .. math::
        \forall a \in \mathcal{A}_x, h_a(t=0) = [0, \dots, 0].

    Only the final state :math:`h(t=1)` is returned.

    Implementation relies on Patrick Kidger's `Diffrax <https://docs.kidger.site/diffrax/>`_.

    :param phi: Outer MLP :math:`\phi_\theta`. Its ``out_size`` sets the dimension of the
        address latent coordinates :math:`h_a`.
    :param message_functions: Non-empty list of message functions :math:`(\psi^i_\theta)_i`.
        Their messages are concatenated along the last axis before being fed to ``phi``.
    :param dt: Initial step size value (passed to Diffrax as ``dt0``).
    :param stepsize_controller: Controller for adaptive step size methods.
    :param adjoint: Method used for backpropagation through the ODE solve.
    :param solver: Numerical solver for the ODE.
    :param max_steps: Maximum number of steps allowed for the solving of the ODE.
    :raises ValueError: If ``message_functions`` is empty.
    """

    def __init__(
        self,
        phi: MLP,
        message_functions: list[MessagePassingFunction],
        dt: float,
        stepsize_controller: diffrax.AbstractStepSizeController,
        adjoint: diffrax.AbstractAdjoint,
        solver: diffrax.AbstractSolver,
        max_steps: int,
    ):
        super().__init__()
        # The vector field concatenates one message per function; with no functions,
        # ``jnp.concatenate`` would only fail later, deep inside the traced ODE solve.
        if not message_functions:
            raise ValueError("NODECoupler requires at least one message function.")
        self.phi = phi
        self.message_functions = nnx.List(message_functions)
        self.dt = dt
        self.stepsize_controller = stepsize_controller
        self.solver = solver
        self.adjoint = adjoint
        self.max_steps = max_steps

    def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[jax.Array, dict]:
        """
        Solve the Neural ODE on ``graph`` and return the address coordinates at :math:`t=1`.

        :param graph: Input graph; it is passed to the solver as ``args`` and forwarded to
            every message function.
        :param get_info: Accepted for interface compatibility; currently unused.
        :return: A tuple ``(coordinates, info)`` where ``coordinates`` has shape
            ``(n_addresses, phi.out_size)`` and ``info`` is always an empty dict.
        """

        def vector_field(t, coordinates, graph):
            """Right-hand side of the Neural ODE: ``dh/dt = phi(psi_1(h), ..., psi_n(h))``."""
            # Each message function returns ``(message, info)``; only the message is used here.
            messages = [m(graph=graph, coordinates=coordinates)[0] for m in self.message_functions]
            return self.phi(jnp.concatenate(messages, axis=-1))

        # Zero initial state: one row per entry along the first axis of
        # ``non_fictitious_addresses``, with ``phi.out_size`` latent features.
        n_addresses = jnp.shape(graph.non_fictitious_addresses)[0]
        h_0 = jnp.zeros([n_addresses, self.phi.out_size])

        # NOTE(review): ``phi`` and the message functions are captured by closure rather than
        # passed through ``args``; confirm the configured ``adjoint`` differentiates through
        # closed-over parameters as intended.
        solution = diffrax.diffeqsolve(
            terms=diffrax.ODETerm(vector_field),
            solver=self.solver,
            t0=0.0,
            t1=1.0,
            dt0=self.dt,
            y0=h_0,
            saveat=diffrax.SaveAt(t1=True),  # only the final state is stored
            args=graph,
            stepsize_controller=self.stepsize_controller,
            adjoint=self.adjoint,
            max_steps=self.max_steps,
        )
        # ``ys`` carries a leading time axis holding the single saved state at t1.
        return solution.ys[-1], {}

    @staticmethod
    def log_solved():
        """Log a message indicating successful ODE solve."""
        logger.info("ODE solved")