# Source code for energnn.model.normalizer.center_reduce_normalizer

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

import jax
import jax.numpy as jnp
from flax import nnx

from energnn.graph import GraphStructure, JaxGraph
from energnn.graph.jax import JaxHyperEdgeSet
from .normalizer import Normalizer


class HyperEdgeSetCenterReduceNormalizer(nnx.Module):
    """
    HyperEdgeSetCenterReduceNormalizer normalizes HyperEdgeSet data feature-wise using running estimates
    of the mean and variance.

    In training mode (``use_running_average=False``), each call updates exponential moving averages of the
    masked batch mean and variance, until ``update_limit`` updates have been performed. The first update
    seeds both averages with the batch statistics themselves, so the weights of each moving average always
    sum to one and no Adam-style bias correction is required (applying one would inflate early estimates,
    e.g. by a factor ``1 / (1 - beta_1)`` right after the first update). Before any update, the mean is zero
    and the variance is one, so the transformation is close to the identity.
    """

    def __init__(
        self,
        n_features: int,
        update_limit: int,
        beta_1: float = 0.9,
        beta_2: float = 0.9,
        epsilon: float = 1e-6,
        use_running_average: bool = False,
    ):
        """
        Initializes the running statistics and the update counter.

        :param n_features: Number of features (size of the last axis of the inputs).
        :param update_limit: Maximum number of updates of the running statistics; once reached, they are frozen.
        :param beta_1: Exponential decay rate of the running mean. Defaults to 0.9.
        :param beta_2: Exponential decay rate of the running variance. Defaults to 0.9.
        :param epsilon: Small value added to the standard deviation to prevent division by zero. Defaults to 1e-6.
        :param use_running_average: If True, the running statistics are used without being updated.
            Defaults to False. Automatically set to True in `eval` mode and to `False` in `train` mode.
        """
        self.n_features = n_features
        self.update_limit = nnx.Variable(jnp.array([update_limit]))
        self.use_running_average = use_running_average
        self.epsilon = epsilon
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        # Number of updates performed so far, stored with shape (1,).
        self.updates = nnx.Variable(jnp.array([0]))

        # Initial statistics make the normalization (almost) an identity until the first update.
        self.mean = nnx.Variable(jnp.zeros(n_features))
        self.var = nnx.Variable(jnp.ones(n_features))

    def __call__(self, x: jax.Array, mask: jax.Array | None = None) -> jax.Array:
        """
        Normalizes ``x`` feature-wise and, in training mode, updates the running statistics.

        :param x: Input features of shape ``(n_items, F)`` or ``(B, n_items, F)``.
        :param mask: Array broadcastable to ``x`` (e.g. ``(..., n_items, 1)``). Items whose mask is zero are
            excluded from the statistics and zeroed in the output. Defaults to all ones.
        :return: ``(x - mean) / (sqrt(var) + epsilon) * mask``, with the same shape as ``x``.
        :raises ValueError: If ``x`` is not of rank 2 or 3, or does not have ``n_features`` features.
        """
        # Axes over which the statistics are reduced: every axis but the feature axis.
        if x.ndim == 2:
            reduce_axes = (0,)
        elif x.ndim == 3:
            reduce_axes = (0, 1)
        else:
            raise ValueError("Input x must be shape (n_items,F) or (B,n_items,F)")
        if x.shape[-1] != self.n_features:
            raise ValueError(f"Input x must have {self.n_features} features, got {x.shape[-1]}.")

        if mask is None:
            mask = jnp.ones(x.shape[:-1] + (1,), dtype=x.dtype)
        valid = mask != 0.0

        current_mean = x.mean(axis=reduce_axes, where=valid)
        current_var = x.var(axis=reduce_axes, where=valid)

        # NOTE(review): relies on the private nnx `Variable._can_update` flag; stop_gradient is applied to the
        # stored statistics only when it is truthy — confirm this is the intended condition.
        if self.mean._can_update or self.var._can_update:
            stop_gradient = jax.lax.stop_gradient
        else:

            def stop_gradient(_x):
                return _x

        # Running statistics are only updated when not using the running average (training mode).
        # jnp.where is used instead of Python branching so that this also works under jit.
        is_training = not self.use_running_average
        # When every item is masked out, the batch statistics are NaN (empty reduction): skip the update
        # so they cannot poison the running averages.
        should_update = is_training & (self.updates[...] < self.update_limit[...])[0] & jnp.any(valid)

        # The first update seeds the moving averages with the batch statistics.
        first_update = self.updates[...] == 0
        new_mean = jnp.where(
            first_update,
            current_mean,
            self.beta_1 * self.mean[...] + (1 - self.beta_1) * current_mean,
        )
        new_var = jnp.where(
            first_update,
            current_var,
            self.beta_2 * self.var[...] + (1 - self.beta_2) * current_var,
        )

        self.mean[...] = stop_gradient(jnp.where(should_update, new_mean, self.mean[...]))
        self.var[...] = stop_gradient(jnp.where(should_update, new_var, self.var[...]))
        self.updates[...] = jnp.where(should_update, self.updates[...] + 1, self.updates[...])

        # No bias correction: the averages were seeded with the first batch statistics, so they are unbiased.
        mean = self.mean[...]
        std = jnp.sqrt(self.var[...])
        return (x - mean) / (std + self.epsilon) * mask


class CenterReduceNormalizer(Normalizer):
    r"""
    Graph-level wrapper that maintains an HyperEdgeSetCenterReduceNormalizer for each hyper-edge set key.

    For a given feature of a given hyper-edge set class, the output is defined as follows.

    .. math::
        x' = \frac{x - \mu}{\sqrt{\sigma^2} + \epsilon}

    where :math:`\mu` (resp. :math:`\sigma^2`) is the exponential moving average of the empirical mean
    (resp. variance) with decay rate `beta_1` (resp. `beta_2`).

    Hyper-edge sets that have no normalizer (no features in ``in_structure``) are passed through unchanged.

    :param in_structure: GraphStructure of the input graph.
    :param update_limit: Threshold for the maximum updates to be performed.
    :param beta_1: Exponential decay rate for the first moment estimates. Defaults to 0.9.
    :param beta_2: Exponential decay rate for the second moment estimates. Defaults to 0.9.
    :param epsilon: Small constant added to improve numerical stability. Defaults to 1e-6.
    :param use_running_average: Flag that indicates whether to use a running average or not. Defaults to False.
        Automatically set to True in `eval` mode and to `False` in `train` mode.
    """

    def __init__(
        self,
        in_structure: GraphStructure,
        update_limit: int,
        beta_1: float = 0.9,
        beta_2: float = 0.9,
        epsilon: float = 1e-6,
        use_running_average: bool = False,
    ):
        self.in_structure = in_structure
        self.update_limit = update_limit
        self.use_running_average = use_running_average
        self.epsilon = epsilon
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.module_dict = self._build_module_dict()

    def _build_module_dict(self) -> dict[str, HyperEdgeSetCenterReduceNormalizer]:
        """Creates a Center Reduce Normalizer module for each hyper-edge set key that has at least one feature."""
        module_dict = {}
        for key, hyper_edge_set_structure in self.in_structure.hyper_edge_sets.items():
            if hyper_edge_set_structure.feature_list is not None and len(hyper_edge_set_structure.feature_list) > 0:
                in_size = len(hyper_edge_set_structure.feature_list)
                module_dict[key] = HyperEdgeSetCenterReduceNormalizer(
                    in_size,
                    update_limit=self.update_limit,
                    beta_1=self.beta_1,
                    beta_2=self.beta_2,
                    epsilon=self.epsilon,
                    use_running_average=self.use_running_average,
                )
        return nnx.data(module_dict)

    @staticmethod
    def _normalize_hyper_edge_set(
        hyper_edge_set: JaxHyperEdgeSet, normalizer: HyperEdgeSetCenterReduceNormalizer
    ) -> JaxHyperEdgeSet:
        """Returns a new hyper-edge set whose feature array has been passed through ``normalizer``."""
        array = hyper_edge_set.feature_array
        # Sets with no items (axis -2 is the item axis) are left untouched.
        if array is not None and array.shape[-2] > 0:
            array = normalizer(array, jnp.expand_dims(hyper_edge_set.non_fictitious, -1))
        return JaxHyperEdgeSet(
            feature_array=array,
            feature_names=hyper_edge_set.feature_names,
            non_fictitious=hyper_edge_set.non_fictitious,
            port_dict=hyper_edge_set.port_dict,
        )

    def __call__(self, *, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict]:
        """
        Applies normalization to the hyper-edge sets of a JaxGraph using HyperEdgeSetCenterReduceNormalizer.

        Feature arrays of hyper-edge sets that have a normalizer are normalized; the other hyper-edge sets
        are kept as-is so that the graph structure is preserved.

        :param graph: JaxGraph containing hyper-edge sets with feature arrays to be normalized.
        :param get_info: Boolean flag that indicates whether to return quantile information about the input
            and output graphs.
        :return: A tuple containing the normalized JaxGraph and a dictionary which, when ``get_info`` is True,
            holds the quantiles of the input and output graphs (empty otherwise).
        """
        normalized_hyper_edge_sets = {
            key: (
                self._normalize_hyper_edge_set(hyper_edge_set, self.module_dict[key])
                if key in self.module_dict
                else hyper_edge_set
            )
            for key, hyper_edge_set in graph.hyper_edge_sets.items()
        }
        normalized_context = JaxGraph(
            hyper_edge_sets=normalized_hyper_edge_sets,
            non_fictitious_addresses=graph.non_fictitious_addresses,
            true_shape=graph.true_shape,
            current_shape=graph.current_shape,
        )
        if get_info:
            info = {"input_graph": graph.quantiles(), "output_graph": normalized_context.quantiles()}
        else:
            info = {}
        return normalized_context, info