Source code for energnn.model.normalizer.normalizer

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod

from flax import nnx

from energnn.graph import JaxGraph


class Normalizer(nnx.Module, ABC):
    """Abstract base class that every graph normalizer implements.

    A normalizer maps the features of an input graph onto a distribution that
    is better suited to neural network training, for instance through
    standardization or normalization.
    """

    @abstractmethod
    def __call__(self, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict]:
        """Apply the normalization to the features of ``graph``.

        :param graph: Graph whose features are to be normalized.
        :param get_info: When True, extra information is returned for tracking purposes.
        :return: A pair made of:
            - the normalized graph, carrying the transformed features;
            - a dictionary of additional information when ``get_info=True``,
              an empty dictionary otherwise.
        :raises NotImplementedError: Always raised by this base implementation;
            subclasses are expected to override the method.
        """
        # The base class only declares the contract; concrete normalizers supply the logic.
        raise NotImplementedError()