Source code for energnn.model.normalizer.tdigest_normalizer

# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

from functools import partial
from typing import Any, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from fastdigest import TDigest
from flax import nnx
from jax import ShapeDtypeStruct
from jax.experimental import io_callback

from energnn.graph import GraphStructure
from energnn.graph.jax import JaxGraph, JaxHyperEdgeSet
from .normalizer import Normalizer


def _merge_equal_quantiles_host(p: np.ndarray, q: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Give tied quantiles a single shared probability, feature by feature.

    For every column of `q` (after casting to float32), rows holding the same
    quantile value are grouped together and each of them receives the mean of
    the probabilities of its group. Rows whose quantile value is unique in the
    column keep their original probability.

    :param p: Probability grid of shape (K, F).
    :param q: Quantile matrix of shape (K, F); column `f` holds the quantiles of feature `f`.
    :return: Tuple (p_merged, q_merged), both float32 arrays of shape (K, F).
        `p_merged` holds the group-averaged probabilities and `q_merged` is `q` cast to float32.
    """
    n_points, n_features = q.shape
    quantiles = q.astype(np.float32)
    merged_p = np.zeros((n_points, n_features), dtype=np.float32)

    for col in range(n_features):
        # Group the rows of this column by their float32 quantile value.
        uniques, group_of_row, group_sizes = np.unique(
            quantiles[:, col], return_inverse=True, return_counts=True
        )
        # Float64 sum of the probabilities inside each group, then the group mean.
        group_sums = np.bincount(group_of_row, weights=p[:, col], minlength=uniques.size)
        group_means = group_sums / group_sizes
        # Broadcast each group's mean back onto the rows that belong to it.
        merged_p[:, col] = group_means[group_of_row].astype(np.float32)

    return merged_p, quantiles


def _ingest_new_data(
    max_centroids: Sequence[int],
    min_val: Sequence[float],
    max_val: Sequence[float],
    centroids_m: np.ndarray,
    centroids_c: np.ndarray,
    fp: np.ndarray,
    xp: np.ndarray,
    array: np.ndarray,
    mask: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Host-side callback to update T-Digest statistics using the fastdigest library.

    For every feature column, a T-Digest is rebuilt from the stored state, fed with
    the unmasked rows of the new batch, and serialized back. New interpolation
    points are then obtained by querying each digest on a uniform probability grid
    of `fp.shape[0]` points and merging tied quantiles.

    :param max_centroids: Maximum number of centroids allowed for each feature.
    :param min_val: Current minimum value for each feature (NaN is replaced by 0.0).
    :param max_val: Current maximum value for each feature (NaN is replaced by 0.0).
    :param centroids_m: Centroid means, one column per feature; rows with a
        non-positive count in `centroids_c` are ignored.
    :param centroids_c: Centroid counts, one column per feature.
    :param fp: Current interpolation probabilities; only its first dimension
        (the number of breakpoints) is used.
    :param xp: Current interpolation quantiles. Not read by this function; it is
        accepted so the callback receives the complete state tuple.
    :param array: New data batch of shape (N, F) or (B, N, F).
    :param mask: Mask for valid data, of shape (N, 1) or (B, N, 1).
    :return: Tuple (max_centroids, min, max, centroids_m, centroids_c, fp, xp) with
        the updated state. The returned `fp` is the merged probability grid mapped to [-1, 1].
    :raises ValueError: If `array` is neither 2-D nor 3-D.
    """
    if array.ndim == 3:
        # Batched input: fold the batch axis into the item axis.
        n_batch, n_items, n_feat = array.shape
        array = array.reshape(n_batch * n_items, n_feat)
        mask = mask.reshape(n_batch * n_items, 1)
    elif array.ndim != 2:
        raise ValueError(f"Expected `array` of shape (N, F) or (B, N, F), got shape {array.shape}.")

    # Keep only the rows flagged as valid.
    array = array[mask.flatten().astype(bool)]

    n_features = array.shape[-1]
    n_breakpoints = fp.shape[0]
    # The probability grid is identical for every feature, so build it once.
    p_grid = np.linspace(0, 1, n_breakpoints)

    new_max_centroids = []
    new_min_list = []
    new_max_list = []
    new_c_c_list = []
    new_c_m_list = []
    new_p_matrix = []
    new_q_matrix = []

    for i in range(n_features):
        limit = int(max_centroids[i])
        lower = float(min_val[i])
        upper = float(max_val[i])

        digest_state = {
            "max_centroids": limit,
            # Handle NaNs from initialization.
            "min": 0.0 if np.isnan(lower) else lower,
            "max": 0.0 if np.isnan(upper) else upper,
            # Zero-count rows are padding, not real centroids.
            "centroids": [
                {"m": float(m), "c": float(c)}
                for m, c in zip(centroids_m[:, i], centroids_c[:, i])
                if c > 0.0
            ],
        }

        tdigest = TDigest.from_dict(digest_state)
        feature_values = array[:, i]
        if len(feature_values) > 0:
            tdigest.batch_update(np.array(feature_values))

        snapshot = tdigest.to_dict()
        new_max_centroids.append(snapshot["max_centroids"])
        new_min_list.append(snapshot["min"])
        new_max_list.append(snapshot["max"])

        # Pad the centroid arrays with zeros up to the per-feature limit so that
        # every feature yields a column of the same length.
        counts = np.array([centroid["c"] for centroid in snapshot["centroids"]])
        counts = np.pad(counts, (0, limit - len(counts)), mode="constant", constant_values=0.0)
        new_c_c_list.append(counts)
        means = np.array([centroid["m"] for centroid in snapshot["centroids"]])
        means = np.pad(means, (0, limit - len(means)), mode="constant", constant_values=0.0)
        new_c_m_list.append(means)

        # NOTE(review): if every row was masked out and no centroid is stored yet,
        # the digest is still empty here — confirm how fastdigest's `quantile`
        # behaves in that case.
        new_p_matrix.append(p_grid)
        new_q_matrix.append(np.asarray([tdigest.quantile(float(prob)) for prob in p_grid]))

    new_p_matrix = np.stack(new_p_matrix, axis=0)  # (F, K)
    new_q_matrix = np.stack(new_q_matrix, axis=0)  # (F, K)

    p_merged, q_merged = _merge_equal_quantiles_host(new_p_matrix.T, new_q_matrix.T)
    new_xp = q_merged.astype(np.float32)
    # Map probabilities from [0, 1] to [-1, 1].
    new_fp = (-1.0 + 2.0 * p_merged).astype(np.float32)

    return (
        np.array(new_max_centroids, dtype=np.int32),
        np.array(new_min_list, dtype=np.float32),
        np.array(new_max_list, dtype=np.float32),
        np.stack(new_c_m_list, axis=1).astype(np.float32),
        np.stack(new_c_c_list, axis=1).astype(np.float32),
        new_fp,
        new_xp,
    )


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
def _tdigest_apply(
    array: jax.Array,
    non_fictitious: jax.Array,
    should_update: jax.Array,
    module_state: tuple[jax.Array, ...],
    in_size: int,
    max_centroids: int,
    n_breakpoints: int,
) -> tuple[jax.Array, ...]:
    """
    Returns the T-Digest state, refreshed on the host when `should_update` is true.

    When `should_update` holds, the current state together with `array` and
    `non_fictitious` is handed to `_ingest_new_data` through `io_callback` and
    the callback's result is returned. Otherwise the incoming state is returned
    unchanged. The function carries a custom VJP; `in_size`, `max_centroids`
    and `n_breakpoints` are declared as non-differentiable arguments.

    :param array: Batch of feature values forwarded to the host callback.
    :param non_fictitious: Validity mask forwarded to the host callback.
    :param should_update: Scalar predicate selecting the update branch.
    :param module_state: Tuple (max_centroids, min, max, centroids_m, centroids_c, fp, xp).
    :param in_size: Number of features; sizes the per-feature outputs.
    :param max_centroids: Number of centroid rows in the callback outputs.
    :param n_breakpoints: Number of rows in the `fp` / `xp` callback outputs.
    :return: State tuple in the same order as `module_state`.
    """
    max_centroids_val, min_val, max_val, centroids_m, centroids_c, fp, xp = module_state
    current_state = (max_centroids_val, min_val, max_val, centroids_m, centroids_c, fp, xp)

    # Shapes and dtypes that the host callback is declared to return.
    per_feature = (in_size,)
    centroid_table = (max_centroids, in_size)
    breakpoint_table = (n_breakpoints, in_size)
    result_shapes = (
        ShapeDtypeStruct(per_feature, jnp.int32),  # max_centroids
        ShapeDtypeStruct(per_feature, jnp.float32),  # min
        ShapeDtypeStruct(per_feature, jnp.float32),  # max
        ShapeDtypeStruct(centroid_table, jnp.float32),  # centroids_m
        ShapeDtypeStruct(centroid_table, jnp.float32),  # centroids_c
        ShapeDtypeStruct(breakpoint_table, jnp.float32),  # fp
        ShapeDtypeStruct(breakpoint_table, jnp.float32),  # xp
    )

    def refresh_on_host(batch: jax.Array, valid_mask: jax.Array) -> tuple[jax.Array, ...]:
        # Run the fastdigest update outside of the traced computation.
        return io_callback(_ingest_new_data, result_shapes, *current_state, batch, valid_mask)

    def keep_current(batch: jax.Array, valid_mask: jax.Array) -> tuple[jax.Array, ...]:
        # No update requested: pass the existing state straight through.
        return current_state

    return jax.lax.cond(should_update, refresh_on_host, keep_current, array, non_fictitious)


def _tdigest_apply_fwd(
    array: jax.Array,
    non_fictitious: jax.Array,
    should_update: jax.Array,
    module_state: tuple[jax.Array, ...],
    in_size: int,
    max_centroids: int,
    n_breakpoints: int,
) -> tuple[tuple[jax.Array, ...], tuple[jax.Array, ...]]:
    """
    Forward rule of the custom VJP attached to `_tdigest_apply`.

    Evaluates `_tdigest_apply` with the same arguments and stores the inputs
    plus the resulting interpolation tables as residuals.

    :return: Pair (new_state, residuals), where residuals is
        (array, non_fictitious, xp, fp) with `xp` / `fp` taken from the new state.
    """
    updated_state = _tdigest_apply(
        array, non_fictitious, should_update, module_state, in_size, max_centroids, n_breakpoints
    )
    # The state tuple is ordered (..., fp, xp): index 5 is fp, index 6 is xp.
    new_fp = updated_state[5]
    new_xp = updated_state[6]
    residuals = (array, non_fictitious, new_xp, new_fp)
    return updated_state, residuals


def _tdigest_apply_bwd(
    in_size: int, max_centroids: int, n_breakpoints: int, res: tuple[jax.Array, ...], grads: Any
) -> tuple[jax.Array, ...]:
    """
    Backward rule of the custom VJP attached to `_tdigest_apply`.

    The incoming cotangents are ignored: the rule returns an all-zero cotangent
    for `array` and `None` for the three other differentiable inputs.

    :param in_size: Non-differentiable argument, unused.
    :param max_centroids: Non-differentiable argument, unused.
    :param n_breakpoints: Non-differentiable argument, unused.
    :param res: Residuals saved by the forward rule: (array, non_fictitious, xp, fp).
    :param grads: Cotangents of the outputs, unused.
    :return: Tuple (zeros shaped like `array`, None, None, None).
    """
    array, _non_fictitious, _xp, _fp = res
    # Build explicit zeros instead of `0 * array`: the product would turn any
    # inf/NaN entry of the input into a NaN cotangent.
    return jnp.zeros_like(array), None, None, None


# Register `_tdigest_apply_fwd` / `_tdigest_apply_bwd` as the forward and backward
# rules of the custom VJP declared on `_tdigest_apply`.
_tdigest_apply.defvjp(_tdigest_apply_fwd, _tdigest_apply_bwd)


class TDigestModule(nnx.Module):
    """
    T-Digest based normalizer for a fixed number of feature columns.

    The module stores, per feature, the serialized T-Digest state (bounds and
    centroids) and a table of interpolation points (`xp`, `fp`). Calling it maps
    each feature through the piecewise-linear function defined by that table,
    with linear extrapolation outside of it. Unless `use_running_average` is
    set, each call also routes the batch through `_tdigest_apply`, which may
    refresh the stored state via a host callback.
    """

    def __init__(
        self,
        in_size: int,
        update_limit: int,
        n_breakpoints: int,
        max_centroids: int,
        use_running_average: bool,
    ):
        """
        Creates the module and its state variables.

        :param in_size: Number of features to normalize.
        :param update_limit: Number of state updates after which updates stop.
        :param n_breakpoints: Number of rows in the interpolation tables.
        :param max_centroids: Number of centroid rows stored per feature.
        :param use_running_average: If True, calls never update the state (inference mode).
        """
        self.in_size = in_size
        self.update_limit = update_limit
        self.n_breakpoints = n_breakpoints
        self.max_centroids = max_centroids
        self.use_running_average = use_running_average

        def identity_grid() -> jax.Array:
            # (n_breakpoints, in_size) table whose columns all run linearly from -1 to 1.
            return jnp.linspace(-1, 1, self.n_breakpoints)[:, None] + jnp.zeros([1, self.in_size])

        per_feature = (self.in_size,)
        centroid_table = [self.max_centroids, self.in_size]

        # Counter of performed updates, compared against `update_limit`.
        self.updates = nnx.Variable(jnp.array([0], dtype=jnp.int32))

        self.max_centroids_var = nnx.Variable(jnp.full(per_feature, self.max_centroids, dtype=jnp.int32))
        # Bounds start as NaN until a first update replaces them.
        self.min_var = nnx.Variable(jnp.full(per_feature, jnp.nan, dtype=jnp.float32))
        self.max_var = nnx.Variable(jnp.full(per_feature, jnp.nan, dtype=jnp.float32))
        self.centroids_m_var = nnx.Variable(jnp.zeros(centroid_table, dtype=jnp.float32))
        self.centroids_c_var = nnx.Variable(jnp.zeros(centroid_table, dtype=jnp.float32))
        # With identical xp and fp tables the initial mapping is the identity.
        self.fp_var = nnx.Variable(identity_grid())
        self.xp_var = nnx.Variable(identity_grid())

    def __call__(self, array: jax.Array, non_fictitious: jax.Array) -> jax.Array:
        """
        Normalizes `array` with the stored interpolation tables.

        When not in running-average mode, the stored state is first replaced by
        the output of `_tdigest_apply`; the update counter is incremented only
        if it was still below `update_limit`.

        :param array: Input of shape (N, in_size), or (B, N, in_size) when 3-D.
        :param non_fictitious: Mask multiplied with the normalized output.
        :return: Normalized array multiplied by `non_fictitious`.
        """
        is_training = not self.use_running_average
        should_update = is_training & (self.updates[...] < self.update_limit)[0]

        if is_training:
            # Same ordering as the state tuple consumed and produced by `_tdigest_apply`.
            state_vars = (
                self.max_centroids_var,
                self.min_var,
                self.max_var,
                self.centroids_m_var,
                self.centroids_c_var,
                self.fp_var,
                self.xp_var,
            )
            refreshed = _tdigest_apply(
                array,
                non_fictitious,
                should_update,
                tuple(var[...] for var in state_vars),
                self.in_size,
                self.max_centroids,
                self.n_breakpoints,
            )

            # Write the state back (side effects); gradients never flow into it.
            self.updates[...] = jnp.where(should_update, self.updates[...] + 1, self.updates[...])
            for var, value in zip(state_vars, refreshed):
                var[...] = jax.lax.stop_gradient(value)

        xp = self.xp_var[...]
        fp = self.fp_var[...]

        def normalize_column(values, knots_x, knots_y):
            eps = 1e-6
            # Inside the table: plain piecewise-linear interpolation.
            inside = jnp.interp(values, knots_x, knots_y)
            # Below the first knot: extend the first segment linearly.
            undershoot = jnp.minimum(values - knots_x[0], 0.0)
            below = undershoot * (knots_y[1] - knots_y[0] + eps) / (knots_x[1] - knots_x[0] + eps)
            # Above the last knot: extend the last segment linearly.
            overshoot = jnp.maximum(values - knots_x[-1], 0.0)
            above = overshoot * (knots_y[-1] - knots_y[-2] + eps) / (knots_x[-1] - knots_x[-2] + eps)
            return inside + below + above

        # Map over the feature axis (axis 1 of the data and of both tables).
        over_features = jax.vmap(normalize_column, in_axes=(1, 1, 1), out_axes=1)

        if array.ndim == 3:
            # Additional leading batch axis: map over it as well.
            out = jax.vmap(lambda sample: over_features(sample, xp, fp), in_axes=0, out_axes=0)(array)
        else:
            out = over_features(array, xp, fp)

        return out * non_fictitious


class TDigestNormalizer(Normalizer):
    """
    Graph-level normalizer that maintains a TDigestModule for each hyper-edge set type.

    This normalizer uses T-Digests to map feature distributions to a target grid
    (usually [-1, 1]), providing a non-parametric alternative to standard normalization.
    """

    def __init__(
        self,
        in_structure: GraphStructure,
        update_limit: int,
        n_breakpoints: int = 20,
        max_centroids: int = 1000,
        use_running_average: bool = False,
    ):
        """
        Initializes the TDigestNormalizer.

        :param in_structure: Structure of the input graph.
        :param update_limit: Maximum number of updates allowed for the T-Digests.
        :param n_breakpoints: Number of breakpoints for the interpolation grid.
        :param max_centroids: Maximum number of centroids for each T-Digest.
        :param use_running_average: Initial state for the running average flag.
        """
        self.in_structure = in_structure
        self.update_limit = update_limit
        self.n_breakpoints = n_breakpoints
        self.max_centroids = max_centroids
        self.use_running_average = use_running_average
        self.module_dict = self._build_module_dict()

    def _build_module_dict(self) -> "dict[str, TDigestModule | None]":
        """
        Creates a TDigest module for each hyper-edge set key in the graph structure.

        Hyper-edge sets whose `feature_list` is None get a `None` entry instead of a module.

        :return: Mapping from hyper-edge set key to its module (or None), wrapped in `nnx.data`.
        """
        module_dict = {}
        for key, hyper_edge_set_structure in self.in_structure.hyper_edge_sets.items():
            if hyper_edge_set_structure.feature_list is not None:
                in_size = len(hyper_edge_set_structure.feature_list)
                module_dict[key] = TDigestModule(
                    in_size=in_size,
                    update_limit=self.update_limit,
                    n_breakpoints=self.n_breakpoints,
                    max_centroids=self.max_centroids,
                    use_running_average=self.use_running_average,
                )
            else:
                module_dict[key] = None
        return nnx.data(module_dict)

    def set_running_average(self, use: bool):
        """
        Sets the running average flag for the normalizer and all its sub-modules.

        :param use: If True, enables inference mode (no updates).
        """
        self.use_running_average = use
        # module_dict is wrapped in nnx.data
        for module in self.module_dict.values():
            # Feature-less hyper-edge sets are stored as None and have no flag to set.
            if module is not None:
                module.use_running_average = use

    def __call__(self, *, graph: JaxGraph, get_info: bool = False) -> tuple[JaxGraph, dict]:
        """
        Apply normalization to hyper-edge sets within a JaxGraph context using TDigest modules.

        Every hyper-edge set of `graph` whose key is known to this normalizer is rebuilt
        with its feature array normalized; hyper-edge sets with unknown keys are not
        part of the returned graph.

        :param graph: JaxGraph representing the graph structure containing hyper-edge sets
            with feature arrays to be normalized.
        :param get_info: Boolean flag that indicates whether to return additional information
            about input and output graphs.
        :return: A tuple containing the normalized JaxGraph and a dictionary that holds the
            quantiles of the input and output graphs when `get_info` is True, and is empty otherwise.
        """
        hyper_edge_set_norm_dict = {
            k: (hyper_edge_set, self.module_dict[k])
            for k, hyper_edge_set in graph.hyper_edge_sets.items()
            if k in self.module_dict.keys()
        }

        def apply_norm(edge_norm: tuple[JaxHyperEdgeSet, TDigestModule]) -> JaxHyperEdgeSet:
            hyper_edge_set, normalizer = edge_norm
            array = hyper_edge_set.feature_array
            # Only normalize when there are features and at least one item along axis -2.
            if hyper_edge_set.feature_array is not None:
                if hyper_edge_set.feature_array.shape[-2] > 0:
                    array = normalizer(array, jnp.expand_dims(hyper_edge_set.non_fictitious, -1))
            return JaxHyperEdgeSet(
                feature_array=array,
                feature_names=hyper_edge_set.feature_names,
                non_fictitious=hyper_edge_set.non_fictitious,
                port_dict=hyper_edge_set.port_dict,
            )

        # Treat each (hyper_edge_set, module) pair as a single leaf of the tree.
        normalized_hyper_edge_sets = jax.tree.map(
            apply_norm, hyper_edge_set_norm_dict, is_leaf=(lambda x: isinstance(x, tuple))
        )
        normalized_context = JaxGraph(
            hyper_edge_sets=normalized_hyper_edge_sets,
            non_fictitious_addresses=graph.non_fictitious_addresses,
            true_shape=graph.true_shape,
            current_shape=graph.current_shape,
        )
        if get_info:
            info = {"input_graph": graph.quantiles(), "output_graph": normalized_context.quantiles()}
        else:
            info = {}
        return normalized_context, info