# Source code for energnn.model.encoder.encoder
# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from flax import nnx
from energnn.graph import JaxGraph
[docs]
class Encoder(nnx.Module, ABC):
    """Abstract interface shared by all graph encoders.

    Concrete encoders subclass this and implement :meth:`__call__`, which maps
    an input :class:`JaxGraph` to an encoded :class:`JaxGraph` together with a
    dictionary of optional tracking information.
    """

    @abstractmethod
    def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict]:
        """Encode ``graph`` into a graph with the same hyper-edge set classes and features.

        :param graph: Graph to be encoded.
        :param get_info: When True, extra information is returned for tracking purposes.
        :return: A pair ``(encoded_graph, info)`` where:

            - ``encoded_graph`` is the graph with transformed features;
            - ``info`` holds additional information when ``get_info=True``,
              and is an empty dict otherwise.
        :raises NotImplementedError: When this base implementation is reached,
            i.e. the subclass did not provide its own ``__call__``.
        """
        # The base class only declares the contract; it has no behaviour of its own.
        raise NotImplementedError()
[docs]
class IdentityEncoder(Encoder):
    r"""Identity encoder that returns the input graph unchanged.

    .. math::
        \tilde{x} = x
    """

    def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict]:
        """Apply the identity encoder and return the input graph without changes.

        :param graph: Input graph to encode; it is returned as-is.
        :param get_info: Accepted for interface compatibility with
            :class:`Encoder`; the identity encoder produces no additional
            information, so this flag does not affect the result.
        :return: A tuple containing:

            - The very same ``graph`` object that was passed in
            - An empty dictionary
        """
        # No transformation is applied and no tracking info is produced.
        return graph, {}