Graph

In this package, the classes Graph and JaxGraph are the core data representation. They are used to represent contexts \(x\) (i.e. input data), decisions \(y\) (i.e. output data), and gradients \(\nabla_y f\). A Graph (resp. JaxGraph) is composed of multiple HyperEdgeSet (resp. JaxHyperEdgeSet) objects, each defined by a series of ports and features.

A Graph or JaxGraph can represent either a single graph instance or a batch of graphs.

Note

JaxGraph (resp. JaxHyperEdgeSet, JaxGraphShape) is the JAX implementation of Graph (resp. HyperEdgeSet, GraphShape), which is based on NumPy.

Here is a typical instance of Graph or JaxGraph.

>>> print(graph)
Mass
          ports      features
            node_id    weight         x         y         z
object_id
0               0.0  5.322265  0.202435  0.202435  0.242032
1               1.0  3.496568  0.962326  0.962326  0.306690
2               2.0  3.535864  0.060886  0.060886  0.094170
3               3.0  7.213709  0.984766  0.984766  0.068853
Spring
          ports              features
           node1_id node2_id         k
object_id
0               0.0      1.0  0.020424
1               1.0      2.0  0.037591
2               2.0      3.0  0.045405
Registry
[0. 1. 2. 3.]
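
The printed layout can be approximated with plain dictionaries of NumPy arrays. The sketch below is not the package API: the names mass, spring, and registry are hypothetical, the values are a subset of those printed above, and it assumes that the registry lists the addresses that ports may refer to.

```python
import numpy as np

# Hypothetical plain-NumPy stand-in for the printed graph (not the package API).
mass = {
    "ports": {"node_id": np.array([0.0, 1.0, 2.0, 3.0])},
    "features": {
        "weight": np.array([5.322265, 3.496568, 3.535864, 7.213709]),
        "x": np.array([0.202435, 0.962326, 0.060886, 0.984766]),
    },
}
spring = {
    "ports": {
        "node1_id": np.array([0.0, 1.0, 2.0]),
        "node2_id": np.array([1.0, 2.0, 3.0]),
    },
    "features": {"k": np.array([0.020424, 0.037591, 0.045405])},
}
registry = np.array([0.0, 1.0, 2.0, 3.0])

# Assumption: every address referenced by a port is present in the registry.
for hyper_edge_set in (mass, spring):
    for port_name, addresses in hyper_edge_set["ports"].items():
        assert np.isin(addresses, registry).all(), port_name

# One row per object: 4 masses, 3 springs.
n_mass = mass["ports"]["node_id"].shape[0]
n_spring = spring["ports"]["node1_id"].shape[0]
```

Each hyper-edge set holds one row per object, and the only link between Mass and Spring objects is the set of addresses shared through their ports.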

Graph

class Graph(*, hyper_edge_sets, true_shape, current_shape, non_fictitious_addresses)[source]

Bases: dict

Hyper Heterogeneous Multi-Graph (H2MG) container.

Stores hyper-edge sets, shapes, and address masks for single or batched graphs.

Parameters:
  • hyper_edge_sets (dict[str, HyperEdgeSet]) – Dictionary of hyper-edge sets contained in the graph.

  • true_shape (GraphShape) – True shape of the graph, not altered by padding.

  • current_shape (GraphShape) – Current shape of the graph, consistent with padding.

  • non_fictitious_addresses (np.ndarray) – Mask filled with ones for real addresses, and zeros otherwise.

Graph.from_dict

Builds a graph from a dictionary of energnn.graph.HyperEdgeSet and a registry.

Graph.to_pickle

Saves a graph as a pickle file.

Graph.from_pickle

Loads a graph from a pickle file.

Graph.is_batch

Determines if the graph is batched.

Graph.is_single

Determines if the graph is single.

Graph.feature_flat_array

Returns an array that concatenates features of all hyper-edge sets.

Graph.pad

Pads hyper-edge sets and address mask to match target_shape.

Graph.unpad

Removes padding to restore true_shape.

Graph.count_connected_components

Counts the connected components and computes the component id of each address.
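
As an illustration of the idea, connected components can be computed with a union-find pass over the port columns. This is a minimal sketch, not the package implementation: the function count_components is hypothetical, and it assumes that two addresses belong to the same component whenever a hyper-edge has ports on both.

```python
import numpy as np

def count_components(n_addresses, port_array):
    """Union-find over addresses; each hyper-edge links all addresses it touches."""
    parent = np.arange(n_addresses)

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # port_array has shape (n_edges, n_ports): one row per hyper-edge.
    for row in port_array:
        root = find(int(row[0]))
        for address in row[1:]:
            parent[find(int(address))] = root

    roots = np.array([find(i) for i in range(n_addresses)])
    # Relabel roots as consecutive component ids 0, 1, 2, ...
    _, component_ids = np.unique(roots, return_inverse=True)
    return int(component_ids.max()) + 1, component_ids

# Springs from the example graph: a chain 0-1-2-3.
springs = np.array([[0, 1], [1, 2], [2, 3]])
n_components, component_ids = count_components(4, springs)

# Dropping the middle spring splits the chain in two.
n_split, split_ids = count_components(4, springs[[0, 2]])
```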

Graph.offset_addresses

Adds an offset on all addresses.

Graph.quantiles

Computes quantiles of hyper-edge set features.
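
The summary does not say which quantiles are computed or how they are returned. The following is a rough sketch under the assumption that quantiles are taken over objects, independently for each feature column; feature_names, feature_array, and q are hypothetical values.

```python
import numpy as np

# Hypothetical feature array of shape (n_obj, n_features).
feature_names = {"weight": 0, "x": 1}
feature_array = np.array(
    [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
)

q = np.array([0.0, 0.5, 1.0])
# Quantiles over objects (axis 0), one column per feature.
quantiles = np.quantile(feature_array, q, axis=0)  # shape (len(q), n_features)
per_feature = {name: quantiles[:, idx] for name, idx in feature_names.items()}
```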

class JaxGraph(*, hyper_edge_sets, true_shape, current_shape, non_fictitious_addresses)[source]

Bases: dict

JAX implementation of the Hyper Heterogeneous Multi-Graph (H2MG) container.

Stores hyper-edge sets, shapes, and address masks for single or batched graphs.

Parameters:
  • hyper_edge_sets (dict[str, JaxHyperEdgeSet]) – Dictionary of hyper-edge sets contained in the graph.

  • true_shape (JaxGraphShape) – True shape of the graph, not altered by padding.

  • current_shape (JaxGraphShape) – Current shape of the graph, consistent with padding.

  • non_fictitious_addresses (jax.Array) – Mask filled with ones for real addresses, and zeros otherwise.

JaxGraph.tree_flatten

Flattens the JaxGraph for JAX PyTree compatibility.

JaxGraph.tree_unflatten

Reconstructs a JaxGraph from flattened data, required for JAX compatibility.

JaxGraph.feature_flat_array

Returns an array that concatenates all hyper-edge set features.

JaxGraph.from_numpy_graph

Convert a classical numpy graph to a jax.numpy format for GNN processing.

JaxGraph.to_numpy_graph

Convert a jax.numpy graph for GNN processing to a classical numpy graph.

JaxGraph.quantiles

Computes quantiles of hyper-edge set features.

HyperEdgeSet

class HyperEdgeSet(*, port_dict, feature_array, feature_names, non_fictitious)[source]

Bases: dict

A collection of hyper-edges of the same class, optionally batched.

Internally this is just a dict storing four entries.

Parameters:
  • port_dict (dict[str, np.ndarray] | None) – Mapping from a port name to an array of shape (n_edges,) or (batch, n_edges).

  • feature_array (np.ndarray | None) – Array that contains all hyper-edge features.

  • feature_names (dict[str, int] | None) – Dictionary from feature names to index in feature_array.

  • non_fictitious (np.ndarray) – Mask array set to 1 for non-fictitious objects and to 0 for fictitious objects.

HyperEdgeSet.from_dict

Build a HyperEdgeSet from raw dicts of ports and features.

HyperEdgeSet.array

Concatenate (features, ports) along the last axis.

HyperEdgeSet.is_batch

True if array is 3-D: (batch, n_obj, features+ports).

HyperEdgeSet.is_single

True if array is 2-D: (n_obj, features+ports).
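
The relation between array, is_single, and is_batch can be pictured with plain NumPy. This sketch only mirrors the shapes stated above; the helper functions are hypothetical and not the package properties.

```python
import numpy as np

features = np.array([[0.020424], [0.037591], [0.045405]])  # (n_obj, n_features)
ports = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])     # (n_obj, n_ports)

# Concatenate (features, ports) along the last axis.
single = np.concatenate([features, ports], axis=-1)  # (n_obj, features+ports)
batch = np.stack([single, single])                   # (batch, n_obj, features+ports)

def is_single(array):
    return array.ndim == 2

def is_batch(array):
    return array.ndim == 3
```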

HyperEdgeSet.n_obj

Number of hyper-edges (objects) per instance.

HyperEdgeSet.n_batch

Number of batches.

HyperEdgeSet.port_array

Returns the stacked array of ports, of shape (n_obj, n_ports) or (batch, n_obj, n_ports).

HyperEdgeSet.port_names

Maps a port name to a column index in port_array.

HyperEdgeSet.feature_dict

Unstack feature_array into a dict: feature_name --> array.

HyperEdgeSet.feature_flat_array

Flattens all features into one long vector (one vector per batch element when batched), using Fortran ordering.
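
Fortran ordering means that the flat vector runs through one feature column at a time rather than one object at a time. The sketch below shows the difference with NumPy; how the package handles the batch axis is an assumption here (each batch element is flattened separately).

```python
import numpy as np

# (n_obj=3, n_features=2): column 0 is one feature, column 1 another.
feature_array = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

flat_f = feature_array.reshape(-1, order="F")  # feature by feature
flat_c = feature_array.reshape(-1, order="C")  # object by object

# Assumed batched behaviour: flatten each batch element on its own.
batch = np.stack([feature_array, feature_array + 100.0])          # (2, 3, 2)
flat_batch = np.stack([b.reshape(-1, order="F") for b in batch])  # (2, 6)
```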

HyperEdgeSet.pad

Pads a single HyperEdgeSet so that its shape matches target_shape, filling features with zeros and ports with the maximum integer value.

HyperEdgeSet.unpad

Remove all objects beyond the index target in a single HyperEdgeSet.
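
The padding convention described above (zeros for features, max-int for ports) can be sketched on raw arrays. The functions pad_arrays and unpad_arrays are hypothetical, and the exact integer used by the package is not stated, so the maximum of the port dtype is an assumption.

```python
import numpy as np

def pad_arrays(feature_array, port_array, target_n_obj):
    """Pad along the object axis: zeros for features, max-int for ports."""
    n_obj = feature_array.shape[0]
    n_pad = target_n_obj - n_obj
    if n_pad < 0:
        raise ValueError("target_n_obj must be at least the current number of objects")
    pad_width = ((0, n_pad), (0, 0))
    max_int = np.iinfo(port_array.dtype).max  # assumed fill value for fictitious ports
    padded_features = np.pad(feature_array, pad_width, constant_values=0.0)
    padded_ports = np.pad(port_array, pad_width, constant_values=max_int)
    # Mask: 1 for real objects, 0 for fictitious (padded) ones.
    non_fictitious = np.concatenate([np.ones(n_obj), np.zeros(n_pad)])
    return padded_features, padded_ports, non_fictitious

def unpad_arrays(feature_array, port_array, n_obj):
    """Drop every object beyond index n_obj."""
    return feature_array[:n_obj], port_array[:n_obj]

features = np.array([[0.020424], [0.037591], [0.045405]])
ports = np.array([[0, 1], [1, 2], [2, 3]], dtype=np.int64)

padded_features, padded_ports, mask = pad_arrays(features, ports, target_n_obj=5)
restored_features, restored_ports = unpad_arrays(padded_features, padded_ports, n_obj=3)
```

Filling ports with a very large integer keeps fictitious objects from pointing at any real address, while the mask records which rows are real.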

HyperEdgeSet.offset_addresses

Adds an offset on all addresses.
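
One plausible use of an address offset, stated here as an assumption rather than documented behaviour, is to keep address spaces disjoint when several graphs are merged. A minimal sketch on raw port arrays, with hypothetical variable names:

```python
import numpy as np

ports_a = np.array([[0, 1], [1, 2]])  # graph A uses addresses 0..2
ports_b = np.array([[0, 1]])          # graph B uses addresses 0..1

n_addresses_a = 3
# Shift every address of graph B past the addresses of graph A.
ports_b_shifted = ports_b + n_addresses_a
merged = np.concatenate([ports_a, ports_b_shifted], axis=0)
```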

class JaxHyperEdgeSet(*, port_dict, feature_array, feature_names, non_fictitious)[source]

Bases: dict

JAX implementation of a collection of hyper-edges of the same class, optionally batched.

Internally this is just a dict storing four entries.

Parameters:
  • port_dict (dict[str, jax.Array] | None) – Dictionary that maps port names to address values.

  • feature_array (jax.Array | None) – Array that contains all hyper-edge features.

  • feature_names (dict[str, jax.Array] | None) – Dictionary from feature names to index in feature_array.

  • non_fictitious (jax.Array) – Binary mask filled with ones for non-fictitious objects.

JaxHyperEdgeSet.tree_flatten

Flattens a PyTree, required for JAX compatibility.

JaxHyperEdgeSet.tree_unflatten

Unflattens a PyTree, required for JAX compatibility.

JaxHyperEdgeSet.feature_flat_array

Returns a flat array by concatenating all features together.

JaxHyperEdgeSet.from_numpy_hyper_edge_set

Convert a classical numpy hyper-edge set to a jax.numpy format for GNN processing.

JaxHyperEdgeSet.to_numpy_hyper_edge_set

Convert a jax.numpy hyper-edge set for GNN processing to a classical numpy hyper-edge set.

GraphShape

class GraphShape(*, hyper_edge_sets, addresses)[source]

Bases: dict

Represents the shape of a graph, including the number of hyper-edges per class and the registry size.

This class extends dict and maintains two keys:

  • HYPER_EDGE_SETS: dict mapping hyper-edge set class names to count arrays.

  • ADDRESSES: array representing the number of non-fictitious nodes.

Parameters:
  • hyper_edge_sets (dict[str, np.ndarray]) – Dictionary that contains the number of objects for each class.

  • addresses (np.ndarray) – Number of addresses in the graph.

GraphShape.from_dict

Builds a new GraphShape object from a hyper-edge set dictionary and registry.

GraphShape.to_jsonable_dict

Serialize GraphShape to JSON-friendly dict.

GraphShape.from_jsonable_dict

Deserialize GraphShape from a JSON-friendly dictionary.

GraphShape.max

Returns the maximum shape of 2 graph shapes.

GraphShape.sum

Returns the sum shape of 2 graph shapes.
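
The max and sum of two shapes can be read as element-wise operations on the per-class counts and on the address count. The sketch below uses plain dictionaries; the key strings and the combine helper are hypothetical, since only the constant names HYPER_EDGE_SETS and ADDRESSES are given above.

```python
import numpy as np

# Hypothetical key strings standing in for HYPER_EDGE_SETS and ADDRESSES.
shape_a = {
    "hyper_edge_sets": {"Mass": np.array(4), "Spring": np.array(3)},
    "addresses": np.array(4),
}
shape_b = {
    "hyper_edge_sets": {"Mass": np.array(2), "Spring": np.array(5)},
    "addresses": np.array(3),
}

def combine(a, b, op):
    """Apply op element-wise to the counts of two shapes with the same classes."""
    return {
        "hyper_edge_sets": {
            name: op(a["hyper_edge_sets"][name], b["hyper_edge_sets"][name])
            for name in a["hyper_edge_sets"]
        },
        "addresses": op(a["addresses"], b["addresses"]),
    }

max_shape = combine(shape_a, shape_b, np.maximum)
sum_shape = combine(shape_a, shape_b, np.add)
```

Under this reading, the maximum gives a shape large enough to hold either graph after padding, and the sum gives the shape of both graphs taken together.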

GraphShape.array

Concatenated hyper-edge set shapes as a single array.

GraphShape.is_single

True if the array is 1-D.

GraphShape.is_batch

True if the array is 2-D.

GraphShape.n_batch

Return the batch size.

class JaxGraphShape(*, hyper_edge_sets, addresses)[source]

Bases: dict

PyTree container for storing the number of objects in each class, and addresses in the graph.

This class inherits from dict and stores two keys:

  • hyper_edge_sets – Dictionary that contains the number of objects for each class.

  • addresses – Number of addresses in the graph.

The PyTree methods tree_flatten and tree_unflatten make this object compatible with JAX transformations (jit, vmap, etc.).

JaxGraphShape.tree_flatten

Flatten the JaxGraphShape for JAX PyTree compatibility.

JaxGraphShape.tree_unflatten

Reconstruct a JaxGraphShape from flattened data, required for JAX compatibility.

JaxGraphShape.from_numpy_shape

Convert a classical numpy shape to a jax.numpy format for GNN processing.

JaxGraphShape.to_numpy_shape

Convert a jax.numpy shape for GNN processing to a classical numpy shape.

Graph, hyper-edge set, and shape manipulation functions

The following functions help manipulate graphs, hyper-edge sets, and shape objects, and perform operations on them.

collate_graphs

Collate a list of Graphs into a single Graph with padded shapes.

concatenate_graphs

Concatenates multiple graphs into a single graph.

get_statistics

Extract summary statistics from each feature array in the graph's hyper-edge sets.

separate_graphs

Splits a batched (collated) Graph into a list of single Graphs.
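
Collating and separating can be pictured at the level of a single feature array: pad every instance to a common size, stack along a new batch axis, and keep a mask to undo the operation. This is a simplified sketch with hypothetical helpers collate and separate; the package functions operate on whole Graph objects.

```python
import numpy as np

def collate(feature_arrays):
    """Pad each (n_obj, n_features) array to the largest n_obj, then stack."""
    target = max(a.shape[0] for a in feature_arrays)
    padded, masks = [], []
    for a in feature_arrays:
        n_pad = target - a.shape[0]
        padded.append(np.pad(a, ((0, n_pad), (0, 0))))  # zero padding
        masks.append(np.concatenate([np.ones(a.shape[0]), np.zeros(n_pad)]))
    return np.stack(padded), np.stack(masks)

def separate(batch, masks):
    """Recover the original arrays by dropping masked-out rows."""
    return [b[m.astype(bool)] for b, m in zip(batch, masks)]

first = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
second = np.array([[7.0, 8.0]])

batch, masks = collate([first, second])
recovered = separate(batch, masks)
```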

check_hyper_edge_set_dict_type

Validate that the provided mapping is a dictionary of HyperEdgeSet instances.

collate_hyper_edge_sets

Collate a list of HyperEdgeSet into a single batched HyperEdgeSet.

concatenate_hyper_edge_sets

Concatenate several single HyperEdgeSet into one single HyperEdgeSet.

separate_hyper_edge_sets

Separate a batched HyperEdgeSet into its constituent HyperEdgeSet instances.

check_dict_shape

Ensure all arrays in a dictionary have the same size on their last axis.

build_hyper_edge_set_shape

Builds a numpy array representing the number of hyper-edges.

dict2array

Stack a dictionary of arrays into a single array along the last axis.
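
Stacking a dictionary of arrays along the last axis, and the reverse operation, can be sketched as follows. The variable names are hypothetical, and the column order is assumed to follow the dictionary's insertion order.

```python
import numpy as np

feature_dict = {"weight": np.array([5.0, 3.0]), "x": np.array([0.2, 0.9])}

# Assumed convention: column index follows insertion order of the dictionary.
feature_names = {name: i for i, name in enumerate(feature_dict)}
feature_array = np.stack(
    [feature_dict[name] for name in feature_names], axis=-1
)  # (n_obj, n_features)

# Reverse operation: unstack columns back into a name -> array mapping.
recovered = {name: feature_array[..., i] for name, i in feature_names.items()}
```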

check_dict_or_none

Validate that the input is either a dict or None.

check_no_nan

Ensure there are no NaN values in port or feature arrays.
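
The shape and NaN checks amount to simple predicates over a dictionary of arrays. The helpers below are hypothetical and return booleans for illustration; the package functions may raise an error instead.

```python
import numpy as np

def same_last_axis(arrays):
    """True if all arrays share the same size on their last axis."""
    sizes = {a.shape[-1] for a in arrays.values()}
    return len(sizes) <= 1

def has_nan(arrays):
    """True if any array contains a NaN value."""
    return any(np.isnan(a).any() for a in arrays.values())

consistent = {"node1_id": np.array([0.0, 1.0, 2.0]), "node2_id": np.array([1.0, 2.0, 3.0])}
ragged = {"node1_id": np.array([0.0, 1.0, 2.0]), "node2_id": np.array([1.0, 2.0])}
with_nan = {"k": np.array([0.02, np.nan, 0.04])}
```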

collate_shapes

Batches a list of GraphShape into one batched GraphShape.

max_shape

Returns the maximum graph shape from a list of graph shapes.

separate_shapes

Splits a batched GraphShape into individual GraphShape instances.

sum_shapes

Returns the sum graph shape from a list of graph shapes.

to_numpy

Converts a NumPy array, JAX array, or tuple of values into a NumPy array (dtype float32), or converts the values in a dictionary accordingly.
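
A conversion of this kind can be sketched with np.asarray applied recursively to dictionary values. The helper as_float32 is hypothetical and only handles the NumPy and tuple cases shown here.

```python
import numpy as np

def as_float32(value):
    """Convert an array-like, or the values of a dict, to float32 NumPy arrays."""
    if isinstance(value, dict):
        return {key: as_float32(item) for key, item in value.items()}
    return np.asarray(value, dtype=np.float32)

from_tuple = as_float32((1, 2, 3))
from_dict = as_float32({"k": np.array([0.5, 1.5]), "x": (0, 1)})
```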

np_to_jnp

Convert NumPy arrays or dictionary of NumPy arrays to JAX arrays.

jnp_to_np

Convert JAX arrays or mappings of JAX arrays back to NumPy arrays.