# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from typing import Any
import numpy as np
import pandas as pd
from energnn.graph.utils import to_numpy
# Keys under which a HyperEdgeSet stores its four entries in the underlying dict.
FEATURE_ARRAY = "feature_array"
FEATURE_NAMES = "feature_names"
PORT_DICT = "port_dict"
NON_FICTITIOUS = "non_fictitious"


class HyperEdgeSet(dict):
    """
    A collection of hyper-edges of the same class, optionally batched.

    Internally this is just a dict storing four entries, under the keys
    `PORT_DICT`, `FEATURE_ARRAY`, `FEATURE_NAMES` and `NON_FICTITIOUS`.

    :param port_dict: Mapping from a port name to an array of shape `(n_edges,)` or `(batch, n_edges)`.
    :param feature_array: Array that contains all hyper-edge features.
    :param feature_names: Dictionary from feature names to index in `feature_array`.
    :param non_fictitious: Mask array set to 1 for non-fictitious objects and to 0 for fictitious objects.
    """

    def __init__(
        self,
        *,
        port_dict: dict[str, np.ndarray] | None,
        feature_array: np.ndarray | None,
        feature_names: dict[str, Any] | None,
        non_fictitious: np.ndarray | None,
    ) -> None:
        super().__init__()
        self[PORT_DICT] = port_dict
        self[FEATURE_ARRAY] = feature_array
        self[FEATURE_NAMES] = feature_names
        self[NON_FICTITIOUS] = non_fictitious

    @classmethod
    def from_dict(
        cls,
        *,
        port_dict: dict[str, Any] | None = None,
        feature_dict: dict[str, Any] | None = None,
    ) -> HyperEdgeSet:
        """
        Build a HyperEdgeSet from raw dicts of ports and features.

        Either input may be None, in which case the corresponding properties are set to None,
        but at least one of them must be provided: the number of objects is inferred from them.
        Features are stacked in alphabetical order of their names, and `feature_names` maps each
        name to its position in that order. The `non_fictitious` mask is a 1-D array of ones
        whose length is the size of the last axis of the inputs.

        :param port_dict: Dictionary of ports, each key corresponds to a port name and the values are the
            corresponding addresses for each object stored into an array.
        :param feature_dict: Dictionary of features, each key corresponds to a feature name and the values are the
            corresponding features for each object stored into an array.
        :returns: A properly structured `HyperEdgeSet` instance.
        :raises ValueError: If both inputs are None, if an input is neither a dict nor None,
            if ports or features contain NaNs, if ports are not integer-valued, or if shapes mismatch.
        """
        # Convert inputs to pure numpy arrays / dicts
        port_dict = check_dict_or_none(to_numpy(port_dict))
        feature_dict = check_dict_or_none(to_numpy(feature_dict))

        # NaNs are checked first so that a NaN port is reported as a NaN, not as a non-integer address.
        check_no_nan(port_dict=port_dict, feature_dict=feature_dict)
        check_valid_ports(port_dict)

        # Build feature_names and feature_array (same alphabetical order in both)
        if feature_dict is not None:
            feature_names = {name: idx for idx, name in enumerate(sorted(feature_dict))}
            feature_array = dict2array(feature_dict)
        else:
            feature_names, feature_array = None, None

        # Build a non-fictitious mask.
        shape = build_hyper_edge_set_shape(port_dict=port_dict, feature_dict=feature_dict)
        non_fictitious = np.ones(int(shape))

        return cls(
            port_dict=port_dict,
            feature_array=feature_array,
            feature_names=feature_names,
            non_fictitious=non_fictitious,
        )

    def __str__(self) -> str:
        """
        Render the HyperEdgeSet as a pandas DataFrame string.

        If `is_single`, uses a single-level index: object_id.
        If `is_batch`, uses a two-level index: batch_id, object_id.
        Columns are grouped under "ports" and "features", each sorted by name.

        :returns: String representation of a `pandas.DataFrame`.
        :raises ValueError: If the internal array has unexpected dimensions.
        """
        if self.is_single:
            index = pd.MultiIndex.from_product([range(self.n_obj)], names=["object_id"])
        elif self.is_batch:
            index = pd.MultiIndex.from_product(
                [range(self.n_batch), range(self.n_obj)],
                names=["batch_id", "object_id"],
            )
        else:
            raise ValueError("HyperEdgeSet is neither single nor batched.")

        columns = {}
        port_dict = self.port_dict
        if port_dict is not None:
            for name in sorted(port_dict):
                columns[("ports", name)] = port_dict[name].reshape([-1])
        # `feature_dict` is None when there are no feature names at all (None or empty mapping).
        feature_dict = self.feature_dict
        if feature_dict is not None:
            for name in sorted(feature_dict):
                columns[("features", name)] = feature_dict[name].reshape([-1])
        return str(pd.DataFrame(columns, index=index))

    @property
    def array(self) -> np.ndarray:
        """
        Concatenate (features, ports) along the last axis.

        :returns: Combined array of shape

            - single: `(n_obj, n_feats + n_ports)`
            - batch: `(batch, n_obj, n_feats + n_ports)`
        :raises ValueError: If the set holds neither features nor ports.
        """
        parts = [part for part in (self.feature_array, self.port_array) if part is not None]
        if not parts:
            raise ValueError("HyperEdgeSet holds neither features nor ports, no array can be built.")
        return np.concatenate(parts, axis=-1)

    @property
    def is_batch(self) -> bool:
        """
        True if `array` is 3-D: `(batch, n_obj, features+ports)`.
        """
        return len(self.array.shape) == 3

    @property
    def is_single(self) -> bool:
        """
        True if `array` is 2-D: `(n_obj, features+ports)`.
        """
        return len(self.array.shape) == 2

    @property
    def n_obj(self) -> int:
        """
        Number of hyper-edges (objects) per instance.

        :raises ValueError: If `array` is neither 2-D nor 3-D.
        """
        # Build the combined array once and read everything from its shape.
        shape = self.array.shape
        if len(shape) == 2:
            return int(shape[0])
        if len(shape) == 3:
            return int(shape[1])
        raise ValueError("HyperEdgeSet is neither single nor batched.")

    @property
    def n_batch(self) -> int:
        """
        Number of batches. Only valid if `is_batch` is True.

        :raises ValueError: If not a batch.
        """
        shape = self.array.shape
        if len(shape) == 3:
            return int(shape[0])
        raise ValueError("HyperEdgeSet is not batched.")

    @property
    def feature_array(self) -> np.ndarray | None:
        """
        Array of all features, with the features on the last axis, or None.
        """
        return self[FEATURE_ARRAY]

    @feature_array.setter
    def feature_array(self, value: np.ndarray) -> None:
        self[FEATURE_ARRAY] = value

    @property
    def feature_names(self) -> dict[str, Any] | None:
        """
        Maps a feature name to its index on the last axis of `feature_array`, or None.

        Indices are plain ints when built by `from_dict`, and arrays holding one copy of the
        index per batch element when built by `collate_hyper_edge_sets`.
        """
        return self[FEATURE_NAMES]

    @property
    def port_array(self) -> np.ndarray | None:
        """
        Returns the stacked array of ports, of shape `(n_obj, n_ports)` or `(batch, n_obj, n_ports)`.
        """
        if self.port_dict is None:
            return None
        return dict2array(self.port_dict)

    @property
    def port_names(self) -> dict[str, np.ndarray] | None:
        """
        Maps a port name to a column index in `port_array`.
        """
        if self.port_dict is None:
            return None
        return {name: np.array(idx) for idx, name in enumerate(sorted(self.port_dict.keys()))}

    @property
    def port_dict(self) -> dict[str, np.ndarray] | None:
        """
        Mapping from a port name to its array of addresses, or None.
        """
        return self[PORT_DICT]

    @port_dict.setter
    def port_dict(self, value: dict[str, np.ndarray] | None) -> None:
        self[PORT_DICT] = value

    @property
    def non_fictitious(self) -> np.ndarray | None:
        """
        Mask of shape `(n_obj,)` or `(batch, n_obj)`.

        1 = real hyper-edge, 0 = padded/fictitious.
        """
        return self[NON_FICTITIOUS]

    @non_fictitious.setter
    def non_fictitious(self, value: np.ndarray) -> None:
        self[NON_FICTITIOUS] = value

    @property
    def feature_dict(self) -> dict[str, np.ndarray] | None:
        """
        Unstack `feature_array` into a dict: feature_name --> array.

        :returns: Dict of arrays of shape `(n_obj,)` or `(batch, n_obj)` per feature,
            or None when there are no feature names.
        """
        names = self.feature_names
        if not names:
            return None
        features = self.feature_array
        batched = self.is_batch
        result = {}
        for name, idx in names.items():
            if batched:
                # After collation an index is replicated once per batch element; every copy is
                # the same, so the first one is used. Flattening first also accepts a plain int.
                idx = np.asarray(idx).reshape(-1)[0]
            # The last axis holds features
            result[name] = features[..., np.array(idx, int)]
        return result

    @property
    def feature_flat_array(self) -> np.ndarray | None:
        """
        Flatten all features into one long vector per `(batch, )` by Fortran ordering.

        :returns:
            Single instance: 1D array of length `n_obj * n_feats`.
            Batched instance: 2D array of shape `(batch, n_obj * n_feats)`.
            None when there is no feature array.
        """
        if self.feature_array is None:
            return None
        shape = [self.n_batch, -1] if self.is_batch else -1
        return self.feature_array.reshape(shape, order="F")

    @feature_flat_array.setter
    def feature_flat_array(self, array: np.ndarray) -> None:
        """
        Update the feature array from a flat Fortran-ordered array.

        :param array: Must match the shape of current `.feature_flat_array`.
        :raises ValueError: If shapes mismatch, or if there is no feature array to update.
        """
        flat = self.feature_flat_array
        if flat is None or flat.shape != array.shape:
            raise ValueError("Shape mismatch for feature_flat_array setter.")
        if self.feature_names is not None:
            if self.is_single:
                self.feature_array = array.reshape([self.n_obj, -1], order="F")
            elif self.is_batch:
                self.feature_array = array.reshape([self.n_batch, self.n_obj, -1], order="F")

    def pad(self, target_shape: np.ndarray | int) -> None:
        """
        Pad a *single* HyperEdgeSet in place with trailing zeros so that shapes match the `target_shape`.

        Features, ports and the `non_fictitious` mask are all padded with zeros.
        NOTE(review): this docstring used to advertise max-int padding for ports, while the
        implementation zero-pads them; confirm which one is intended.

        :param target_shape: Desired n_obj after padding; must be ≥ current n_obj.
        :raises ValueError: If called on a batch or if target_shape < current n_obj.
        """
        if not self.is_single:
            raise ValueError("HyperEdgeSet is batched, impossible to pad.")
        old_n_obj = self.n_obj
        if old_n_obj > target_shape:
            raise ValueError("Provided target_shape is smaller than current shape, padding is impossible! ")
        n_extra = int(target_shape) - old_n_obj

        # Pad features (objects axis only, the feature axis is untouched)
        if self.feature_array is not None:
            self.feature_array = np.pad(self.feature_array, [(0, n_extra), (0, 0)])

        # Pad ports; values are replaced key by key so the dict object itself is kept
        port_dict = self.port_dict
        if port_dict is not None:
            for name in port_dict:
                port_dict[name] = np.pad(port_dict[name], [0, n_extra])

        # Pad fictitious mask: the added objects are flagged as fictitious (0)
        if self.non_fictitious is not None:
            self.non_fictitious = np.pad(self.non_fictitious, [0, n_extra])

    def unpad(self, target_shape: np.ndarray | int) -> None:
        """
        Remove in place all objects at index `target_shape` and beyond in a *single* HyperEdgeSet.

        :param target_shape: New n_obj; must be ≤ current n_obj.
        :raises ValueError: If called on a batch or if target_shape > current n_obj.
        """
        if not self.is_single:
            raise ValueError("HyperEdgeSet is batched, impossible to unpad.")
        if self.n_obj < target_shape:
            raise ValueError("Provided target_shape is higher than current shape, unpadding is impossible! ")
        n_kept = int(target_shape)

        # Unpad features
        if self.feature_array is not None:
            self.feature_array = self.feature_array[:n_kept]

        # Unpad ports; values are replaced key by key so the dict object itself is kept
        port_dict = self.port_dict
        if port_dict is not None:
            for name in port_dict:
                port_dict[name] = port_dict[name][:n_kept]

        # Unpad fictitious mask
        if self.non_fictitious is not None:
            self.non_fictitious = self.non_fictitious[:n_kept]

    def offset_addresses(self, offset: np.ndarray | int) -> None:
        """
        Adds an offset on all addresses. Should only be used before graph concatenation.

        Does nothing when the set has no ports.

        :param offset: Scalar or array to add to each address array.
        """
        if self.port_dict is None:
            return
        shift = np.array(offset)
        self.port_dict = {name: addresses + shift for name, addresses in self.port_dict.items()}
def collate_hyper_edge_sets(hyper_edge_set_list: list[HyperEdgeSet]) -> HyperEdgeSet:
    """
    Collate a list of HyperEdgeSet into a single batched HyperEdgeSet.

    Every entry (feature array, feature name indices, port arrays, non-fictitious mask) is stacked
    along a new leading axis. Which entries exist is decided from the first element of the list;
    the other elements are checked against it for port and feature names.

    :param hyper_edge_set_list: Sequence of HyperEdgeSet objects to batch together. Must be non-empty.
    :return: A single batched HyperEdgeSet.
    :raises IndexError: Raised if `hyper_edge_set_list` is empty.
    :raises ValueError: Raised if not all HyperEdgeSet share the same keys in port_names or feature_names.
    """
    if len(hyper_edge_set_list) == 0:
        raise IndexError("collate_edges requires at least one Edge to collate.")

    reference = hyper_edge_set_list[0]

    # Every other element must expose the same port / feature names as the reference
    for other in hyper_edge_set_list[1:]:
        _check_keys_consistency(reference, other)

    # Feature arrays
    feature_array = None
    if reference.feature_array is not None:
        feature_array = np.stack([hes.feature_array for hes in hyper_edge_set_list], axis=0)

    # Feature names: one index per batch element for each name
    feature_names = None
    if reference.feature_names is not None:
        feature_names = {}
        for name in reference.feature_names:
            feature_names[name] = np.stack([hes.feature_names[name] for hes in hyper_edge_set_list])

    # Port dicts
    port_dict = None
    if reference.port_dict is not None:
        port_dict = {}
        for name in reference.port_dict:
            port_dict[name] = np.stack([hes.port_dict[name] for hes in hyper_edge_set_list])

    # Non-fictitious masks
    non_fictitious = None
    if reference.non_fictitious is not None:
        non_fictitious = np.stack([hes.non_fictitious for hes in hyper_edge_set_list])

    return HyperEdgeSet(
        port_dict=port_dict,
        feature_array=feature_array,
        feature_names=feature_names,
        non_fictitious=non_fictitious,
    )
def separate_hyper_edge_sets(hyper_edge_set_batch: HyperEdgeSet) -> list[HyperEdgeSet]:
    """
    Separate a batched HyperEdgeSet into its constituent HyperEdgeSet instances.

    The input HyperEdgeSet must have been created by :py:func:`collate_hyper_edge_sets` or otherwise
    its property "array" must return a 3D array. Every entry is unstacked along its first axis with
    `numpy.unstack`; entries that are None in the batch stay None in every output.

    :param hyper_edge_set_batch: The batched HyperEdgeSet to unstack.
    :return: List of HyperEdgeSet instances, each corresponding to one batch element.
    :raises ValueError: If `hyper_edge_set_batch.is_batch` is False.
    """
    batch = hyper_edge_set_batch
    if not batch.is_batch:
        raise ValueError("Input is not a batch, impossible to separate.")

    # Feature arrays: one 2-D slice per batch element
    if batch.feature_array is None:
        feature_arrays = [None] * batch.n_batch
    else:
        feature_arrays = np.unstack(batch.feature_array, axis=0)

    # Feature names: unstack each index array, then regroup per batch element
    if batch.feature_names is None:
        feature_names_per_item = [None] * batch.n_batch
    else:
        unstacked_names = {name: np.unstack(batch.feature_names[name]) for name in batch.feature_names}
        feature_names_per_item = [dict(zip(unstacked_names, row)) for row in zip(*unstacked_names.values())]

    # Ports: same regrouping as for the feature names
    if batch.port_dict is None:
        port_dicts = [None] * batch.n_batch
    else:
        unstacked_ports = {name: np.unstack(batch.port_dict[name]) for name in batch.port_dict}
        port_dicts = [dict(zip(unstacked_ports, row)) for row in zip(*unstacked_ports.values())]

    # Non-fictitious masks
    if batch.non_fictitious is None:
        masks = [None] * batch.n_batch
    else:
        masks = np.unstack(batch.non_fictitious, axis=0)

    return [
        HyperEdgeSet(port_dict=ports, feature_array=features, feature_names=names, non_fictitious=mask)
        for features, names, ports, mask in zip(feature_arrays, feature_names_per_item, port_dicts, masks)
    ]
def concatenate_hyper_edge_sets(hyper_edge_set_list: list[HyperEdgeSet]) -> HyperEdgeSet:
    """
    Concatenate several single HyperEdgeSet into one single HyperEdgeSet.

    Unlike :py:func:`collate_hyper_edge_sets`, this does *not* create a batch dimension,
    but simply stacks objects end-to-end along the first axis. Which entries exist (ports,
    features, mask) is decided from the first element of the list; entries that are None there
    stay None in the result. `feature_names` is taken from the first element.

    :param hyper_edge_set_list: List of single (non-batched) HyperEdgeSet. Must be non-empty.
    :returns: One HyperEdgeSet with n_obj = sum of all inputs' n_obj
    :raises IndexError: If `hyper_edge_set_list` is empty.
    :raises ValueError: If not all HyperEdgeSet share the same keys in port_names or feature_names.
    """
    if not hyper_edge_set_list:
        raise IndexError("concatenate_hyper_edge_sets requires at least one HyperEdgeSet to concatenate.")

    first = hyper_edge_set_list[0]

    # Same safety net as in collation: silently mixing different schemas would mislabel columns.
    for other in hyper_edge_set_list[1:]:
        _check_keys_consistency(first, other)

    if first.port_dict is not None:
        port_dict = {
            name: np.concatenate([hes.port_dict[name] for hes in hyper_edge_set_list]) for name in first.port_dict
        }
    else:
        port_dict = None

    if first.feature_array is not None:
        feature_array = np.concatenate([hes.feature_array for hes in hyper_edge_set_list], axis=0)
    else:
        feature_array = None

    if first.non_fictitious is not None:
        non_fictitious = np.concatenate([hes.non_fictitious for hes in hyper_edge_set_list])
    else:
        non_fictitious = None

    return HyperEdgeSet(
        port_dict=port_dict,
        feature_array=feature_array,
        feature_names=first.feature_names,
        non_fictitious=non_fictitious,
    )
def check_dict_shape(*, d: dict[str, np.ndarray] | None, n_objects: int | None) -> int | None:
    """
    Ensure all arrays in a dictionary have the same size on their last axis.

    If `n_objects` is not provided, it is inferred from the first array's last dimension.
    Otherwise, every array's last dimension must match the given `n_objects`.
    A None or empty dictionary carries no information: `n_objects` is returned unchanged.

    :param d: Mapping from feature/port name to `numpy` array
        where each array's last axis is object-indexed.
    :param n_objects: Optional expected size of the last axis; if None, will be inferred.
    :return: The validated or inferred `n_objects`.
    :raises ValueError: If any array's last dimension does not match `n_objects`,
        or if an array has no axis at all.
    """
    if not d:
        return n_objects

    for name, arr in d.items():
        shape = np.shape(arr)
        if len(shape) == 0:
            raise ValueError(f"Array for key '{name}' is 0-dimensional, expected at least one axis.")
        last_dim = shape[-1]
        if n_objects is None:
            # First array seen fixes the reference size
            n_objects = last_dim
        elif last_dim != n_objects:
            raise ValueError(f"Array for key '{name}' has last dimension {last_dim}, expected {n_objects}.")
    return n_objects
def build_hyper_edge_set_shape(
    *,
    port_dict: dict[str, np.ndarray] | None,
    feature_dict: dict[str, np.ndarray] | None,
) -> np.ndarray:
    """
    Builds a numpy array representing the number of hyper-edges.

    Validate that `port_dict` and `feature_dict` have consistent sizes
    on their last dimensions and return a scalar numpy array containing that count.
    An empty dictionary is treated like None, since no count can be read from it.

    :param port_dict: Mapping from port names to numpy arrays, or None.
    :param feature_dict: Mapping of feature names to numpy arrays, or None.
    :return: A scalar numpy array of dtype float32 with the number of objects.
    :raises ValueError: If both inputs are None or empty, or if their shapes conflict.
    """
    # Normalise empty mappings to None so that they cannot reach the shape inference.
    port_dict = port_dict or None
    feature_dict = feature_dict or None
    if port_dict is None and feature_dict is None:
        raise ValueError("At least one of port_dict or feature_dict must be provided.")

    # Ports fix the reference count (when present), features are then validated against it.
    n_objects = check_dict_shape(d=port_dict, n_objects=None)
    n_objects = check_dict_shape(d=feature_dict, n_objects=n_objects)
    return np.array(n_objects, dtype=np.dtype("float32"))
def dict2array(features_dict: dict[str, np.ndarray] | None) -> np.ndarray | None:
    """
    Stack the arrays of a dictionary into one array, along a new last axis.

    Columns follow the alphabetical order of the dictionary keys.

    :param features_dict: Mapping from a feature name to a `numpy` array, or None.
    :return: The stacked array, whose last axis indexes the sorted keys, or None if the input is None.
    """
    if features_dict is not None:
        ordered_keys = sorted(features_dict)
        columns = [features_dict[key] for key in ordered_keys]
        return np.stack(columns, axis=-1)
    return None
def check_dict_or_none(_input: dict | np.ndarray | None) -> dict | None:
    """
    Pass the input through if it is a dict or None, reject anything else.

    :param _input: Object to validate
    :return: `_input` itself, unchanged
    :raises ValueError: if `_input` is neither dict nor None
    """
    if _input is not None and not isinstance(_input, dict):
        raise ValueError(f"Expected dict or None, got {type(_input)}")
    return _input
def check_no_nan(
    *,
    port_dict: dict[str, np.ndarray] | None,
    feature_dict: dict[str, np.ndarray] | None,
) -> None:
    """
    Reject port or feature arrays that contain NaN values.

    Ports are inspected first, then features; the first offending key is reported.

    :param port_dict: Mapping from port names to arrays, or None.
    :param feature_dict: Mapping of feature names to arrays, or None.
    :raises ValueError: If any array contains NaN.
    """
    nothing_found = object()

    def first_key_with_nan(mapping):
        # First key (in iteration order) whose array holds a NaN, or the sentinel if there is none.
        entries = mapping if mapping else {}
        return next((key for key, values in entries.items() if np.isnan(values).any()), nothing_found)

    name = first_key_with_nan(port_dict)
    if name is not nothing_found:
        raise ValueError(f"NaN detected in port array for key '{name}'.")

    name = first_key_with_nan(feature_dict)
    if name is not nothing_found:
        raise ValueError(f"NaN detected in feature array for key '{name}'.")
def check_valid_ports(port_dict: dict[str, np.ndarray] | None) -> None:
    """
    Ensure that ports map only to integer-valued addresses.

    Arrays with an integer or boolean dtype are valid by construction. Any other array must hold
    finite values that are within 1e-8 (absolute) of an integer. No relative tolerance is
    applied, so large non-integer addresses such as 200000.5 are rejected, and no cast to a
    fixed-width integer is made, so addresses beyond the int32 range are accepted.

    :param port_dict: Mapping from port names to arrays, or None.
    :raises ValueError: If any port array has entries that are not integer.
    """
    for name, arr in (port_dict or {}).items():
        values = np.asarray(arr)
        if values.dtype.kind in "iub":
            continue
        # NaN and infinities are not addresses; isfinite filters them before the rounding test.
        integral = np.isfinite(values) & np.isclose(values, np.rint(values), rtol=0.0, atol=1e-8)
        if not np.all(integral):
            raise ValueError(f"Non-integer values detected in port array for key '{name}'.")
def _check_keys_consistency(hes_1, hes_2):
    """
    Verify that two hyper-edge sets expose the same port names and the same feature names.

    Presence (None vs not None) is compared first for ports, then for features; the key sets are
    compared afterwards, and only when the first set's mapping is non-empty.

    :param hes_1: Reference object exposing `port_names` and `feature_names`.
    :param hes_2: Object compared against the reference.
    :raises ValueError: If presence or keys of port_names / feature_names differ.
    """
    ports_1, ports_2 = hes_1.port_names, hes_2.port_names
    if (ports_1 is None) ^ (ports_2 is None):
        raise ValueError("Mismatch in presence of port_names among hyper-edge sets.")

    features_1, features_2 = hes_1.feature_names, hes_2.feature_names
    if (features_1 is None) ^ (features_2 is None):
        raise ValueError("Mismatch in presence of feature_names among hyper-edge sets.")

    if ports_1:
        if ports_1.keys() != ports_2.keys():
            raise ValueError("Inconsistent port_names keys among hyper-edge sets.")

    if features_1:
        if features_1.keys() != features_2.keys():
            raise ValueError("Inconsistent feature_names keys among hyper-edge sets.")