Source code for energnn.model.decoder.decoder
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from abc import ABC, abstractmethod
import jax
from flax import nnx
from energnn.graph import JaxGraph
[docs]
class Decoder(ABC, nnx.Module):
    """Interface for all decoders.

    A decoder takes as input latent coordinates and an encoded graph context,
    and produces either a new graph with predictions or a global output vector.

    Concrete decoders must override :meth:`__call__`. Because that method is
    declared with :func:`abc.abstractmethod` on an :class:`abc.ABC` subclass,
    a subclass that does not override it cannot be instantiated.
    """

    @abstractmethod
    def __call__(
        self, *, graph: JaxGraph, coordinates: jax.Array, get_info: bool = False
    ) -> tuple[JaxGraph | jax.Array, dict]:
        """Decode latent coordinates into predictions.

        All arguments are keyword-only.

        :param graph: Encoded graph providing context for decoding.
        :param coordinates: Latent coordinates array with shape (num_addresses, latent_dim).
        :param get_info: If True, returns additional info for tracking purpose.
        :return: A tuple containing:

            - Either a new JaxGraph with prediction features or a global output array
            - A dictionary with additional information if get_info=True, empty dict otherwise

        :raises NotImplementedError: If an overriding subclass delegates to this
            base implementation (e.g. via ``super().__call__(...)``); subclasses
            that do not override the method at all are rejected at instantiation.
        """
        # Name the concrete class so a stray super() call is easy to trace.
        raise NotImplementedError(
            f"{type(self).__name__} must implement "
            "__call__(*, graph, coordinates, get_info=False)."
        )