# Copyright (c) 2025, RTE (http://www.rte-france.com)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
import pickle as pkl
import numpy as np
from energnn.graph.hyper_edge_set import (
HyperEdgeSet,
collate_hyper_edge_sets,
concatenate_hyper_edge_sets,
separate_hyper_edge_sets,
)
from energnn.graph.shape import GraphShape, collate_shapes, separate_shapes, sum_shapes
HYPER_EDGE_SETS = "hyper_edge_sets"
TRUE_SHAPE = "true_shape"
CURRENT_SHAPE = "current_shape"
NON_FICTITIOUS_ADDRESSES = "non_fictitious_addresses"
[docs]
class Graph(dict):
"""
Hyper Heterogeneous Multi-Graph (H2MG) container.
Stores hyper-edge sets, shapes, and address masks for single or batched graphs.
:param hyper_edge_sets: Dictionary of hyper-edge sets contained in the graph.
:param true_shape: True shape of the graph, not altered by padding.
:param current_shape: Current shape of the graph, consistent with padding.
:param non_fictitious_addresses: Mask filled with ones for real addresses, and zeros otherwise.
"""
def __init__(
self,
*,
hyper_edge_sets: dict[str, HyperEdgeSet],
true_shape: GraphShape,
current_shape: GraphShape,
non_fictitious_addresses: np.ndarray,
) -> None:
super().__init__()
self[HYPER_EDGE_SETS] = hyper_edge_sets
self[TRUE_SHAPE] = true_shape
self[CURRENT_SHAPE] = current_shape
self[NON_FICTITIOUS_ADDRESSES] = non_fictitious_addresses
[docs]
@classmethod
def from_dict(cls, *, hyper_edge_set_dict: dict[str, HyperEdgeSet], n_addresses: np.ndarray) -> Graph:
"""
Builds a graph from a dictionary of :class:`energnn.graph.HyperEdgeSet` and a registry.
:param hyper_edge_set_dict: Dictionary of hyper-edge sets contained in the graph.
:param n_addresses: Number of unique addresses that appear in all the hyper-edge sets.
:return: Graph that contains both the hyper-edge sets and the registry.
"""
non_fictitious_addresses = np.ones(shape=[n_addresses])
check_hyper_edge_set_dict_type(hyper_edge_set_dict)
check_valid_addresses(hyper_edge_set_dict, n_addresses)
true_shape = GraphShape.from_dict(hyper_edge_set_dict=hyper_edge_set_dict, non_fictitious=non_fictitious_addresses)
current_shape = true_shape
return cls(
hyper_edge_sets=hyper_edge_set_dict,
true_shape=true_shape,
current_shape=current_shape,
non_fictitious_addresses=non_fictitious_addresses,
)
@property
def true_shape(self) -> GraphShape:
"""
True shape of the graph with the real number of objects for each hyper-edge set
class as well as the size of the registry stored in a GraphShape object.
There is no setter for this property.
:return: A graph shape of true sizes.
"""
return self[TRUE_SHAPE]
@property
def current_shape(self) -> GraphShape:
"""
The current shape of the graph taking into accounts fake padding objects.
:return: A graph shape of current sizes.
"""
return self[CURRENT_SHAPE]
@current_shape.setter
def current_shape(self, value: GraphShape) -> None:
"""
Sets the current shape of the graph taking into accounts fake padding objects.
:param value: A new graph shape.
"""
self[CURRENT_SHAPE] = value
@property
def non_fictitious_addresses(self) -> np.ndarray:
"""
Mask filled with ones for real addresses, and zeros otherwise.
:return: Array filled with ones and zeros.
"""
return self[NON_FICTITIOUS_ADDRESSES]
@non_fictitious_addresses.setter
def non_fictitious_addresses(self, value: np.ndarray):
"""
Sets the address mask.
:param value: Array filled with ones and zeros.
"""
self[NON_FICTITIOUS_ADDRESSES] = value
@property
def hyper_edge_sets(self) -> dict[str, HyperEdgeSet]:
"""
Dictionary of hyper-edge sets.
:return: Dict of edge class to Edge.
"""
return self[HYPER_EDGE_SETS]
@hyper_edge_sets.setter
def hyper_edge_sets(self, hyper_edge_set_dict: dict[str, HyperEdgeSet]) -> None:
"""
Sets the dictionary of hyper-edge sets.
:param hyper_edge_set_dict: New dictionary of hyper-edge sets.
"""
self[HYPER_EDGE_SETS] = hyper_edge_set_dict
def __str__(self) -> str:
r = ""
for k, v in sorted(self.hyper_edge_sets.items()):
r += "{}\n{}\n".format(k, v)
return r
[docs]
def to_pickle(self, file_path: str) -> None:
"""Saves a graph as a pickle file.
:param file_path: Destination path
"""
with open(file_path, "wb") as handle:
pkl.dump(self, handle, protocol=pkl.HIGHEST_PROTOCOL)
[docs]
@classmethod
def from_pickle(cls, *, file_path: str) -> Graph:
"""Loads a graph from a pickle file.
:param file_path: Source path.
:return: Deserialized Graph.
"""
with open(file_path, "rb") as handle:
graph = pkl.load(handle)
return graph
@property
def is_batch(self) -> bool:
"""
Determines if the graph is batched.
:return: True if all hyper-edge sets are batched and if the non-fictitious mask is a 2-D array when defined.
"""
for k, e in self.hyper_edge_sets.items():
if not e.is_batch:
return False
if (self.non_fictitious_addresses is not None) and (len(self.non_fictitious_addresses.shape) != 2):
return False
else:
return True
@property
def is_single(self) -> bool:
"""
Determines if the graph is single.
:return: True if all hyper-edge sets are single and if the non-fictitious mask is a 1-D array when defined.
"""
for k, e in self.hyper_edge_sets.items():
if not e.is_single:
return False
if (self.non_fictitious_addresses is not None) and (len(self.non_fictitious_addresses.shape) != 1):
return False
else:
return True
@property
def feature_flat_array(self) -> np.ndarray:
"""
Returns an array that concatenates features of all hyper-edge sets.
:return: Numpy array of concatenated features.
:raises ValueError: If no hyper-edge set is present.
"""
values_list = []
# Gather hyper-edges set features
if self.hyper_edge_sets is not None:
for key, hyper_edge_set in sorted(self.hyper_edge_sets.items()):
if hyper_edge_set.feature_flat_array is not None:
values_list.append(hyper_edge_set.feature_flat_array)
else:
raise ValueError("This graph does not contain any hyper-edge set, and can't be cast as a flat array.")
return np.concatenate(values_list, axis=-1)
@feature_flat_array.setter
def feature_flat_array(self, value: np.ndarray) -> None:
"""
Updates the flat array contained in the H2MG.
:param value: Flat feature array.
:raises ValueError: If shapes do not match the current feature flat array.
"""
if np.any(self.feature_flat_array.shape != value.shape):
raise ValueError("Invalid array shape.")
i = 0
if self.hyper_edge_sets is not None:
for key, hyper_edge_set in sorted(self.hyper_edge_sets.items()):
if hyper_edge_set.feature_names is not None:
length = np.shape(hyper_edge_set.feature_flat_array)[-1]
if length > 0:
self.hyper_edge_sets[key].feature_flat_array = value[..., i : i + length] # Slice over the last axis
i += length
else:
raise ValueError("This graph does not contain any hyper-edge set, and can't be cast as a flat array.")
[docs]
def pad(self, target_shape: GraphShape) -> None:
"""
Pads hyper-edge sets and address mask to match target_shape.
:param target_shape: Desired GraphShape with larger dimensions.
:raises ValueError: If the graph is not single.
"""
if not self.is_single:
raise ValueError("This graph is not single and cannot be padded.")
for key, hyper_edge_set_shape in target_shape.hyper_edge_sets.items():
self.hyper_edge_sets[key].pad(hyper_edge_set_shape)
self.non_fictitious_addresses = np.pad(
self.non_fictitious_addresses, [0, int(target_shape.addresses) - int(self.current_shape.addresses)]
)
self.current_shape = target_shape
[docs]
def unpad(self) -> None:
"""
Removes padding to restore true_shape.
:raises ValueError: If the graph is not single.
"""
for key, hyper_edge_set_shape in self.true_shape.hyper_edge_sets.items():
self.hyper_edge_sets[key].unpad(hyper_edge_set_shape)
self.non_fictitious_addresses = self.non_fictitious_addresses[: int(self.true_shape.addresses)]
self.current_shape = self.true_shape
[docs]
def count_connected_components(self) -> tuple[int, np.ndarray]:
"""
Counts connected components, and the component id of each address.
:return: `(num_components, component_labels)`
:raises ValueError: If the graph is not single.
"""
def _max_propagate(*, graph: Graph, h_: np.ndarray) -> np.ndarray:
"""Propagates the max value of addresses through hyper-edges."""
import copy
h_new_ = copy.deepcopy(h_)
edge_h = {}
for edge_key, edge in graph.hyper_edge_sets.items():
edge_h[edge_key] = []
for address_key, address_array in edge.port_dict.items():
edge_h[edge_key].append(h_new_[address_array.astype(int)])
edge_h[edge_key] = np.stack(edge_h[edge_key], axis=0)
edge_h[edge_key] = np.max(edge_h[edge_key], axis=0)
for address_key, address_array in edge.port_dict.items():
new_val = np.max(
np.stack([edge_h[edge_key], h_new_[address_array.astype(int)]], axis=0),
axis=0,
)
np.maximum.at(h_new_, address_array.astype(int), new_val)
return h_new_
if not self.is_single:
raise ValueError("Graph is not single.")
h = np.arange(len(self.non_fictitious_addresses))
converged = False
while not converged:
h_new = _max_propagate(graph=self, h_=h)
converged = np.all(h_new == h)
h = h_new
u, indices = np.unique(h, return_inverse=True)
return len(u), indices
[docs]
def offset_addresses(self, offset: np.ndarray | int) -> None:
"""
Adds an offset on all addresses. Should only be used before graph concatenation.
:param offset: Integer or array to add to addresses
"""
for k, e in self.hyper_edge_sets.items():
e.offset_addresses(offset=offset)
[docs]
def quantiles(self, q_list: list[float] | None = None) -> dict[str, np.ndarray]:
"""Computes quantiles of hyper-edge set features.
:param q_list: Percentiles to compute
:return: Mapping "hyper_edge_set/feature/percentile" to values.
:raises ValueError: If the graph is not single or batched and cannot be quantiled.
"""
if q_list is None:
q_list = [0.0, 10.0, 25.0, 50.0, 75.0, 90.0, 100.0]
info = {}
for object_name, hyper_edge_sets in self.hyper_edge_sets.items():
if hyper_edge_sets.feature_dict is not None:
for feature_name, array in hyper_edge_sets.feature_dict.items():
if np.size(array) > 0:
for q in q_list:
if self.is_single:
value = np.nanpercentile(array, q=q)
elif self.is_batch:
value = np.nanpercentile(array, q=q, axis=1)
else:
raise ValueError("This graph is not single or batch and cannot be quantiled.")
info[f"{object_name}/{feature_name}/{q}th-percentile"] = value
return info
[docs]
def collate_graphs(graph_list: list[Graph]) -> Graph:
"""
Collate a list of Graphs into a single Graph with padded shapes.
All input graphs must share the same `current_shape`.
:param graph_list: List of Graph instances to collate.
:returns: A new Graph whose
- `true_shape` is the batch of all `true_shape`s.
- `current_shape` is the batch of all `current_shape`s (they must be identical).
- `hyper_edge_sets` are collated per hyper-edge set class.
- `non_fictitious_addresses` stacked along a new batch dimension.
:raises ValueError: If `graph_list` is an empty list.
:raises AssertionError: If the `current_shape` differs among inputs.
"""
if not graph_list:
raise ValueError("collate_graphs requires at least one Graph.")
first_graph = graph_list[0]
# Assert that all current shapes are equal
current_shape_list = [g.current_shape for g in graph_list]
current_shape = first_graph.current_shape
for s in current_shape_list:
assert s == current_shape
current_shape_batch = collate_shapes(current_shape_list)
true_shape_list = [g.true_shape for g in graph_list]
true_shape_batch = collate_shapes(true_shape_list)
hyper_edge_sets_batch = {}
for k in first_graph.hyper_edge_sets.keys():
hyper_edge_sets_batch[k] = collate_hyper_edge_sets([g.hyper_edge_sets[k] for g in graph_list])
if first_graph.non_fictitious_addresses is not None:
non_fictitious_addresses_batch = np.stack([g.non_fictitious_addresses for g in graph_list], axis=0)
else:
non_fictitious_addresses_batch = None
return Graph(
hyper_edge_sets=hyper_edge_sets_batch,
non_fictitious_addresses=non_fictitious_addresses_batch,
true_shape=true_shape_batch,
current_shape=current_shape_batch,
)
[docs]
def separate_graphs(graph_batch: Graph) -> list[Graph]:
"""
Split a batch of collated Graph into a list of single Graphs.
It reverses the operation of :py:func:`collate_graphs`.
:param graph_batch: A Graph whose `current_shape` and `true_shape` are batched.
:returns: List of Graphs, each corresponding to one element in the batch.
"""
current_shape_list = separate_shapes(graph_batch.current_shape)
true_shape_list = separate_shapes(graph_batch.true_shape)
n_batch = len(current_shape_list)
hyper_edge_set_list_dict = {}
for k in graph_batch.hyper_edge_sets.keys():
hyper_edge_set_list_dict[k] = separate_hyper_edge_sets(graph_batch.hyper_edge_sets[k])
if graph_batch.non_fictitious_addresses is not None:
non_fictitious_addresses_list = np.unstack(graph_batch.non_fictitious_addresses, axis=0)
else:
non_fictitious_addresses_list = [None] * n_batch
hyper_edge_set_dict_list = [
{k: hyper_edge_set_list_dict[k][i] for k in hyper_edge_set_list_dict.keys()} for i in range(n_batch)
]
graph_list = []
for e, n, t, c in zip(hyper_edge_set_dict_list, non_fictitious_addresses_list, true_shape_list, current_shape_list):
graph = Graph(hyper_edge_sets=e, non_fictitious_addresses=n, true_shape=t, current_shape=c)
graph_list.append(graph)
return graph_list
[docs]
def concatenate_graphs(graph_list: list[Graph]) -> Graph:
"""
Concatenates multiple graphs into a single graph.
This function merges a sequence of graphs by combining their non-fictitious addresses,
hyper-edge sets, and shapes into one unified Graph instance. Address offsets are temporarily applied
to avoid collisions between vertex indices during hyper-edge set concatenation, then reverted to preserve
the integrity of the original Graph objects.
:param graph_list: A list of Graph instances to be concatenated.
:return: A new Graph object representing the concatenation of all input graphs.
:raises ValueError: If `graph_list` is empty.
:note: The input graphs are temporarily modified to apply address offsets but are restored
to their original state before the function returns.
"""
if not graph_list:
raise ValueError("graph_list must contain at least one Graph")
n_addresses_list = [len(graph.non_fictitious_addresses) for graph in graph_list]
offset_list = [sum(n_addresses_list[:i]) for i in range(len(n_addresses_list))]
non_fictitious_addresses = np.concatenate([graph.non_fictitious_addresses for graph in graph_list], axis=0)
true_shape = sum_shapes([graph.true_shape for graph in graph_list])
current_shape = sum_shapes([graph.current_shape for graph in graph_list])
[graph.offset_addresses(offset=offset) for graph, offset in zip(graph_list, offset_list)]
hyper_edge_sets = {
k: concatenate_hyper_edge_sets([graph.hyper_edge_sets[k] for graph in graph_list])
for k in graph_list[0].hyper_edge_sets
}
[graph.offset_addresses(offset=-offset) for graph, offset in zip(graph_list, offset_list)]
return Graph(
hyper_edge_sets=hyper_edge_sets,
non_fictitious_addresses=non_fictitious_addresses,
true_shape=true_shape,
current_shape=current_shape,
)
[docs]
def check_hyper_edge_set_dict_type(hyper_edge_set_dict: dict[str, HyperEdgeSet]) -> None:
"""
Validate that the provided mapping is a dictionary of HyperEdgeSet instances.
:param hyper_edge_set_dict: A mapping from string keys to HyperEdgeSet objects.
:raises TypeError: If `hyper_edge_set_dict` is not a dictionary, or if any value in it is not an HyperEdgeSet.
"""
if not isinstance(hyper_edge_set_dict, dict):
raise TypeError("Provided 'hyper_edge_set_dict' is not a 'dict', but a {}.".format(type(hyper_edge_set_dict)))
for key, hyper_edge_set in hyper_edge_set_dict.items():
if not isinstance(hyper_edge_set, HyperEdgeSet):
raise TypeError("Item associated with '{}' key is not an 'hyper_edge_set_dict'.".format(key))
def check_valid_addresses(hyper_edge_set_dict: dict[str, HyperEdgeSet], n_addresses: np.ndarray) -> None:
"""
Ensure that all address indices in each HyperEdgeSet are valid with respect to the registry.
Iterates over all hyper-edge sets in `hyper_edge_set_dict` and, if a hyper-edge set defines `address_names`,
checks that its integer-coded addresses do not exceed the provided count array.
:param hyper_edge_set_dict: Mapping from hyper-edge set names to HyperEdgeSet objects containing address arrays.
:param n_addresses: 1D array where each entry gives the number of valid addresses
for the corresponding hyper-edge set.
:raises AssertionError: If any address in any hyper-edge set is outside the valid range
(i.e., not less than the corresponding entry in `n_addresses`).
"""
for key, hyper_edge_set in hyper_edge_set_dict.items():
if hyper_edge_set.port_names is not None:
assert np.all(hyper_edge_set.port_array < n_addresses)
[docs]
def get_statistics(graph: Graph, axis: int | None = None, norm_graph: Graph | None = None) -> dict:
"""
Extract summary statistics from each feature array in the graph's hyper-edge sets.
For every feature of every hyper-edge in `graph`, computes:
- Root Mean Squared Error (RMSE)
- Mean Absolute Error (MAE)
- First and second moments (mean, standard deviation)
- Range and quantiles (min, 10th, 25th, 50th, 75th, 90th, max)
If `norm_graph` is provided, then it also returns normalized metrics:
- Normalized RMSE (nrmse)
- Normalized MAE (nmae)
:param graph: Graph object containing hyper-edge sets with feature dictionaries.
:param axis: Axis along which to compute statistics. If None, statistics
are computed over the flattened array.
:param norm_graph: Optional Graph whose features serve as normalization reference.
:return: A dictionary mapping keys of the form
``"{hyper_edge_set_name}/{feature_name}/{stat}"`` to their computed values.
Values are floats or numpy arrays depending on `axis`.
"""
# Convert fictitious features to NaN.
for key, hyper_edge_set in graph.hyper_edge_sets.items():
mask = hyper_edge_set.non_fictitious
if hyper_edge_set.feature_array is not None:
graph.hyper_edge_sets[key].feature_array[mask == 0] = np.nan
info = {}
for object_name, hyper_edge_set in graph.hyper_edge_sets.items():
if hyper_edge_set.feature_dict is not None:
for feature_name, array in hyper_edge_set.feature_dict.items():
if array.size == 0:
if axis == 1:
array = np.array([[0.0]])
else:
array = np.array([0.0])
# Root Mean Squared Error
rmse = np.sqrt(np.nanmean(array**2, axis=axis))
info["{}/{}/rmse".format(object_name, feature_name)] = rmse
if norm_graph is not None:
norm_array = norm_graph.hyper_edge_sets[object_name].feature_dict[feature_name]
norm_array = norm_array - np.nanmean(norm_array)
nrmse = rmse / (np.sqrt(np.nanmean(norm_array**2, axis=axis)) + 1e-9)
info["{}/{}/nrmse".format(object_name, feature_name)] = nrmse
# Mean Absolute Error
mae = np.nanmean(np.abs(array), axis=axis)
info["{}/{}/mae".format(object_name, feature_name)] = mae
if norm_graph is not None:
norm_array = norm_graph.hyper_edge_sets[object_name].feature_dict[feature_name]
norm_array = norm_array - np.nanmean(norm_array)
nmae = mae / (np.nanmean(np.abs(norm_array), axis=axis) + 1e-9)
info["{}/{}/nmae".format(object_name, feature_name)] = nmae
# Moments
info["{}/{}/mean".format(object_name, feature_name)] = np.nanmean(array, axis=axis)
info["{}/{}/std".format(object_name, feature_name)] = np.nanstd(array, axis=axis)
# Quantiles
info["{}/{}/max".format(object_name, feature_name)] = np.nanmax(array, axis=axis)
info["{}/{}/90th".format(object_name, feature_name)] = np.nanpercentile(array, q=90, axis=axis)
info["{}/{}/75th".format(object_name, feature_name)] = np.nanpercentile(array, q=75, axis=axis)
info["{}/{}/50th".format(object_name, feature_name)] = np.nanpercentile(array, q=50, axis=axis)
info["{}/{}/25th".format(object_name, feature_name)] = np.nanpercentile(array, q=25, axis=axis)
info["{}/{}/10th".format(object_name, feature_name)] = np.nanpercentile(array, q=10, axis=axis)
info["{}/{}/min".format(object_name, feature_name)] = np.nanmin(array, axis=axis)
return info