Source code for energnn.model.coupler.coupler

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod

import jax
from flax import nnx

from energnn.graph import JaxGraph


class Coupler(nnx.Module, ABC):
    """Abstract base class for couplers.

    A coupler maps an input graph to latent coordinates, one row per address.
    Concrete implementations must inject the graph information into these
    coordinates in a permutation-equivariant manner.
    """

    @abstractmethod
    def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[jax.Array, dict]:
        """Map ``graph`` to latent coordinates.

        :param graph: Graph to encode into latent coordinates.
        :param get_info: When True, implementations should also return extra
            information for tracking purposes.
        :return: A pair ``(coordinates, info)``:

            - ``coordinates``: array of shape ``(num_addresses, latent_dim)``;
            - ``info``: dictionary of additional information when
              ``get_info`` is True, an empty dict otherwise.
        :raises NotImplementedError: On this base class; every concrete
            coupler must override this method.
        """
        # Abstract hook: subclasses provide the actual coupling logic.
        raise NotImplementedError()