Graph
In this package, the classes Graph and JaxGraph are the core data representations.
They are used to represent contexts \(x\) (i.e. input data),
decisions \(y\) (i.e. output data), and gradients \(\nabla_y f\).
A Graph (resp. JaxGraph) is composed of multiple HyperEdgeSet (resp. JaxHyperEdgeSet) objects,
each defined by a series of ports and features.
The class Graph or JaxGraph can represent either a single graph instance or a batch of
graphs.
Note
JaxGraph (resp. JaxHyperEdgeSet, JaxGraphShape) is the JAX implementation
of Graph (resp. HyperEdgeSet, GraphShape), which is based on NumPy.
Here is a typical instance of Graph or JaxGraph.
```
>>> print(graph)
Mass
             ports  features
           node_id    weight         x         y         z
object_id
0              0.0  5.322265  0.202435  0.202435  0.242032
1              1.0  3.496568  0.962326  0.962326  0.306690
2              2.0  3.535864  0.060886  0.060886  0.094170
3              3.0  7.213709  0.984766  0.984766  0.068853
Spring
              ports            features
           node1_id  node2_id         k
object_id
0               0.0       1.0  0.020424
1               1.0       2.0  0.037591
2               2.0       3.0  0.045405
Registry
[0. 1. 2. 3.]
```
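To make the printed example concrete, here is a minimal sketch of how the Mass and Spring hyper-edge sets relate to the registry: each port column holds addresses, and every address appears in the registry, so two objects are connected when they share an address. It uses plain NumPy dictionaries rather than the package's classes; the dict layout and the helper `ports_in_registry` are hypothetical, and only some of the printed features are kept.

```python
import numpy as np

# Registry of addresses shared by every hyper-edge set (values from the printout).
registry = np.array([0.0, 1.0, 2.0, 3.0])

# Hypothetical plain-dict layout (not the package's classes): ports hold
# addresses, features hold numeric attributes. Only some features are kept.
mass = {
    "ports": {"node_id": np.array([0.0, 1.0, 2.0, 3.0])},
    "features": {"weight": np.array([5.322265, 3.496568, 3.535864, 7.213709])},
}
spring = {
    "ports": {
        "node1_id": np.array([0.0, 1.0, 2.0]),
        "node2_id": np.array([1.0, 2.0, 3.0]),
    },
    "features": {"k": np.array([0.020424, 0.037591, 0.045405])},
}


def ports_in_registry(edge_set, registry):
    """Hypothetical check: every port value must be an address of the registry."""
    return all(np.isin(addr, registry).all() for addr in edge_set["ports"].values())


# A spring is linked to the masses whose node_id matches one of its ports.
pairs = zip(spring["ports"]["node1_id"], spring["ports"]["node2_id"], spring["features"]["k"])
for a, b, k in pairs:
    print(f"spring k={k} links masses at addresses {int(a)} and {int(b)}")
```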
Graph
- class Graph(*, hyper_edge_sets, true_shape, current_shape, non_fictitious_addresses)
Bases: dict
Hyper Heterogeneous Multi-Graph (H2MG) container.
Stores hyper-edge sets, shapes, and address masks for single or batched graphs.
- Parameters:
hyper_edge_sets (dict[str, HyperEdgeSet]) – Dictionary of hyper-edge sets contained in the graph.
true_shape (GraphShape) – True shape of the graph, not altered by padding.
current_shape (GraphShape) – Current shape of the graph, consistent with padding.
non_fictitious_addresses (np.ndarray) – Mask filled with ones for real addresses, and zeros otherwise.
- Methods:
Builds a graph from a dictionary of …
Saves a graph as a pickle file.
Loads a graph from a pickle file.
Determines if the graph is batched.
Determines if the graph is single.
Returns an array that concatenates features of all hyper-edge sets.
Pads hyper-edge sets and address mask to match target_shape.
Removes padding to restore true_shape.
Counts connected components, and the component id of each address.
Adds an offset on all addresses.
Computes quantiles of hyper-edge set features.
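The connected-component count listed above can be understood as follows: two addresses belong to the same component when some hyper-edge has ports on both. The sketch below computes this with a union-find (disjoint-set) structure on a single, unpadded graph. The function name and the dict-of-port-arrays input are assumptions for illustration, and the package may well use a different algorithm.

```python
import numpy as np


def connected_components(n_addresses, hyper_edge_sets):
    """Hypothetical sketch: count components and label each address.

    hyper_edge_sets maps a class name to {port_name: int array of addresses},
    with one entry per hyper-edge in each port array.
    """
    parent = list(range(n_addresses))

    def find(i):
        # Follow parents up to the root, halving the path as we go.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[rj] = ri

    for ports in hyper_edge_sets.values():
        if not ports:
            continue
        # Rows are ports, columns are hyper-edges.
        stacked = np.stack(list(ports.values()))
        for edge in stacked.T:
            # All addresses touched by one hyper-edge end up in the same component.
            for addr in edge[1:]:
                union(int(edge[0]), int(addr))

    roots = [find(i) for i in range(n_addresses)]
    # Relabel roots as consecutive component ids 0, 1, 2, ...
    _, component_id = np.unique(roots, return_inverse=True)
    return len(set(roots)), component_id


# Springs link addresses 0-1 and 1-2; addresses 3 and 4 stay isolated.
springs = {"Spring": {"node1_id": np.array([0, 1]), "node2_id": np.array([1, 2])}}
n_components, component_id = connected_components(5, springs)
print(n_components, component_id)  # 3 [0 0 0 1 2]
```

Single-port hyper-edges (such as Mass) never merge components, which matches the intuition that only multi-port objects create connections.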
- class JaxGraph(*, hyper_edge_sets, true_shape, current_shape, non_fictitious_addresses)
Bases: dict
JAX implementation of the Hyper Heterogeneous Multi-Graph (H2MG).
Stores hyper-edge sets, shapes, and address masks for single or batched graphs.
- Parameters:
hyper_edge_sets (dict[str, JaxHyperEdgeSet]) – Dictionary of hyper-edge sets contained in the graph.
true_shape (JaxGraphShape) – True shape of the graph, not altered by padding.
current_shape (JaxGraphShape) – Current shape of the graph, consistent with padding.
non_fictitious_addresses (jax.Array) – Mask filled with ones for real addresses, and zeros otherwise.
- Methods:
Flattens the JaxGraph for JAX PyTree compatibility.
Reconstructs a JaxGraph from flattened data, required for JAX compatibility.
Returns an array that concatenates all hyper-edge set features.
Convert a classical numpy graph to a jax.numpy format for GNN processing.
Convert a jax.numpy graph for GNN processing to a classical numpy graph.
Computes quantiles of hyper-edge set features.
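JaxGraph, JaxHyperEdgeSet and JaxGraphShape provide flatten/unflatten methods so that JAX transformations can see through them. The sketch below illustrates that contract on a toy dict subclass: array leaves go into `children`, and static structure (here, the sorted keys) goes into `aux_data`. It is written in plain Python and NumPy so it runs without JAX; the class `ToyGraph` is hypothetical, and with JAX installed such a class would typically be registered with the `jax.tree_util.register_pytree_node_class` decorator.

```python
import numpy as np


class ToyGraph(dict):
    """Hypothetical dict-based container following JAX's PyTree protocol."""

    def tree_flatten(self):
        # Sort keys so the flattening order is deterministic.
        keys = tuple(sorted(self.keys()))
        children = tuple(self[k] for k in keys)  # dynamic leaves (arrays)
        aux_data = keys  # static metadata, must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the container from the static keys and the (possibly transformed) leaves.
        return cls(zip(aux_data, children))


g = ToyGraph(features=np.array([1.0, 2.0]), mask=np.array([1.0, 0.0]))
children, aux = g.tree_flatten()
# Mimic what a transformation does: map a function over the leaves, then rebuild.
doubled = ToyGraph.tree_unflatten(aux, tuple(2 * c for c in children))
print(doubled["features"])  # [2. 4.]
```

Keeping the keys in `aux_data` rather than in `children` means JAX would treat them as part of the tree structure, so only the arrays are traced.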
HyperEdgeSet
- class HyperEdgeSet(*, port_dict, feature_array, feature_names, non_fictitious)
Bases: dict
A collection of hyper-edges of the same class, optionally batched.
Internally this is just a dict storing four entries.
- Parameters:
port_dict (dict[str, np.ndarray] | None) – Mapping from a port name to an array of shape (n_edges,) or (batch, n_edges).
feature_array (np.ndarray | None) – Array that contains all hyper-edge features.
feature_names (dict[str, int] | None) – Dictionary mapping each feature name to its column index in feature_array.
non_fictitious (np.ndarray) – Mask array set to 1 for non-fictitious objects and to 0 for fictitious objects.
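The four entries listed above can be sketched with plain NumPy for the Spring set of the earlier example. The key names mirror the constructor parameters, but the exact internal keys and the helper `feature_dict` are assumptions. The helper shows how `feature_names` indexes the columns of `feature_array`, which is what unstacking features into a dict amounts to.

```python
import numpy as np

port_dict = {
    "node1_id": np.array([0, 1, 2]),
    "node2_id": np.array([1, 2, 3]),
}
feature_names = {"k": 0}  # feature name -> column index in feature_array
feature_array = np.array([[0.020424], [0.037591], [0.045405]])  # (n_edges, n_features)
non_fictitious = np.ones(3)  # all three springs are real objects

# Hypothetical key names mirroring the constructor parameters.
spring = {
    "port_dict": port_dict,
    "feature_array": feature_array,
    "feature_names": feature_names,
    "non_fictitious": non_fictitious,
}


def feature_dict(edge_set):
    """Hypothetical helper: unstack feature_array into {name: column}.

    Indexing the last axis with [..., idx] also works for batched arrays
    of shape (batch, n_edges, n_features).
    """
    return {
        name: edge_set["feature_array"][..., idx]
        for name, idx in edge_set["feature_names"].items()
    }


print(feature_dict(spring)["k"])  # [0.020424 0.037591 0.045405]
```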
- Methods:
Build a HyperEdgeSet from raw dicts of ports and features.
Concatenate (features, ports) along the last axis.
True if the array is 3-D: (batch, n_obj, features+ports).
True if the array is 2-D: (n_obj, features+ports).
Number of hyper-edges (objects) per instance.
Number of batches.
Returns the stacked array of ports, of shape (n_obj, n_ports) or (batch, n_obj, n_ports).
Maps a port name to a column index in port_array.
Unstack feature_array into a dict: feature_name --> array.
Flatten all features into one long vector per batch element, using Fortran (column-major) ordering.
Pad a single HyperEdgeSet with zeros for features and max-int for ports so that its shape matches target_shape.
Remove all objects beyond the index target in a single HyperEdgeSet.
Adds an offset on all addresses.
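The padding, unpadding and flattening rules above can be sketched with NumPy for a single (non-batched) set: features are padded with zeros, ports with the maximum integer value of their dtype so padded objects cannot collide with a real address, and the `non_fictitious` mask marks the added rows with zeros. The function names, the integer port dtype, and the dict layout are assumptions for illustration, not the package's implementation.

```python
import numpy as np


def pad_edge_set(edge_set, target_n):
    """Hypothetical sketch: pad a single edge set to target_n objects."""
    n = edge_set["feature_array"].shape[0]
    extra = target_n - n
    if extra < 0:
        raise ValueError("target_n is smaller than the current number of objects")
    # Zeros for features (np.pad defaults to constant mode with value 0).
    features = np.pad(edge_set["feature_array"], ((0, extra), (0, 0)))
    # Max-int for ports, assuming integer port arrays.
    ports = {
        name: np.concatenate([addr, np.full(extra, np.iinfo(addr.dtype).max, dtype=addr.dtype)])
        for name, addr in edge_set["port_dict"].items()
    }
    # Padded objects are fictitious.
    mask = np.concatenate([edge_set["non_fictitious"], np.zeros(extra)])
    return {"port_dict": ports, "feature_array": features,
            "feature_names": edge_set["feature_names"], "non_fictitious": mask}


def unpad_edge_set(edge_set, target_n):
    """Hypothetical sketch: drop every object beyond index target_n."""
    return {
        "port_dict": {k: v[:target_n] for k, v in edge_set["port_dict"].items()},
        "feature_array": edge_set["feature_array"][:target_n],
        "feature_names": edge_set["feature_names"],
        "non_fictitious": edge_set["non_fictitious"][:target_n],
    }


spring = {
    "port_dict": {"node1_id": np.array([0, 1, 2]), "node2_id": np.array([1, 2, 3])},
    "feature_array": np.array([[0.020424], [0.037591], [0.045405]]),
    "feature_names": {"k": 0},
    "non_fictitious": np.ones(3),
}
padded = pad_edge_set(spring, 5)
print(padded["non_fictitious"])  # [1. 1. 1. 0. 0.]
restored = unpad_edge_set(padded, 3)

# Fortran (column-major) flattening lists all values of feature 0, then feature 1.
two_features = np.array([[1, 10], [2, 20], [3, 30]])  # (n_obj, n_features)
flat = np.ravel(two_features, order="F")
```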
- class JaxHyperEdgeSet(*, port_dict, feature_array, feature_names, non_fictitious)
Bases: dict
JAX implementation of a collection of hyper-edges of the same class, optionally batched.
Internally this is just a dict storing four entries.
- Parameters:
port_dict (dict[str, jax.Array] | None) – Dictionary that maps port names to address values.
feature_array (jax.Array | None) – Array that contains all hyper-edge features.
feature_names (dict[str, jax.Array] | None) – Dictionary mapping each feature name to its column index in feature_array.
non_fictitious (jax.Array) – Binary mask filled with ones for non-fictitious objects.
- Methods:
Flattens a PyTree, required for JAX compatibility.
Unflattens a PyTree, required for JAX compatibility.
Returns a flat array by concatenating all features together.
Convert a classical numpy hyper-edge set to a jax.numpy format for GNN processing.
Convert a jax.numpy hyper-edge set for GNN processing to a classical numpy hyper-edge set.
GraphShape
- class GraphShape(*, hyper_edge_sets, addresses)
Bases: dict
Represents the shape of a graph: the number of hyper-edges in each class and the size of the registry.
This class extends dict and maintains two keys:
  - HYPER_EDGE_SETS: dict mapping hyper-edge set class names to count arrays.
  - ADDRESSES: array representing the number of non-fictitious nodes.
- Parameters:
hyper_edge_sets (dict[str, np.ndarray]) – Dictionary that contains the number of objects for each class.
addresses (np.ndarray) – Number of addresses in the graph.
- Methods:
Builds a new GraphShape object from a hyper-edge set dictionary and registry.
Serialize GraphShape to a JSON-friendly dict.
Deserialize GraphShape from a JSON-friendly dictionary.
Returns the maximum shape of two graph shapes.
Returns the sum shape of two graph shapes.
Concatenated hyper-edge set shapes as a single array.
True if the array is 1-D.
True if the array is 2-D.
Return the batch size.
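The maximum and sum of two graph shapes can be sketched as element-wise operations on the per-class counts and on the address count. The version below works on plain dicts whose keys mirror the constructor parameters (`hyper_edge_sets`, `addresses`); the function names and the assumption that both shapes list the same classes are illustrative, and the package's GraphShape may handle missing classes differently. Presumably the maximum is a natural padding target when collating graphs, and the sum matches the shape of a concatenated graph.

```python
import numpy as np


def shape_max(a, b):
    """Hypothetical: element-wise maximum of two single graph shapes."""
    return {
        "hyper_edge_sets": {
            k: np.maximum(a["hyper_edge_sets"][k], b["hyper_edge_sets"][k])
            for k in a["hyper_edge_sets"]
        },
        "addresses": np.maximum(a["addresses"], b["addresses"]),
    }


def shape_sum(a, b):
    """Hypothetical: element-wise sum of two single graph shapes."""
    return {
        "hyper_edge_sets": {
            k: a["hyper_edge_sets"][k] + b["hyper_edge_sets"][k]
            for k in a["hyper_edge_sets"]
        },
        "addresses": a["addresses"] + b["addresses"],
    }


# Shape of the example graph (4 masses, 3 springs, 4 addresses) and of another one.
s1 = {"hyper_edge_sets": {"Mass": np.array(4), "Spring": np.array(3)}, "addresses": np.array(4)}
s2 = {"hyper_edge_sets": {"Mass": np.array(6), "Spring": np.array(5)}, "addresses": np.array(6)}

m = shape_max(s1, s2)
s = shape_sum(s1, s2)
print({k: int(v) for k, v in m["hyper_edge_sets"].items()}, int(m["addresses"]))
# {'Mass': 6, 'Spring': 5} 6
print({k: int(v) for k, v in s["hyper_edge_sets"].items()}, int(s["addresses"]))
# {'Mass': 10, 'Spring': 8} 10
```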
- class JaxGraphShape(*, hyper_edge_sets, addresses)
Bases: dict
PyTree container storing the number of objects in each class and the number of addresses in the graph.
This class inherits from dict and stores two keys.
- Parameters:
hyper_edge_sets – Dictionary that contains the number of objects for each class.
addresses – Number of addresses in the graph.
The PyTree methods tree_flatten and tree_unflatten make this object compatible with JAX transformations (jit, vmap, etc.).
- Methods:
Flatten the JaxGraphShape for JAX PyTree compatibility.
Reconstruct a JaxGraphShape from flattened data, required for JAX compatibility.
Convert a classical numpy shape to a jax.numpy format for GNN processing.
Convert a jax.numpy shape for GNN processing to a classical numpy shape.
Graph, hyper-edge set, and shape manipulation functions
The following functions help manipulate graphs, hyper-edge sets, and shape objects, and perform operations on them.
Collate a list of Graphs into a single Graph with padded shapes.
Concatenates multiple graphs into a single graph.
Extract summary statistics from each feature array in the graph's hyper-edge sets.
Split a batch of collated Graphs into a list of single Graphs.
Validate that the provided mapping is a dictionary of HyperEdgeSet instances.
Collate a list of HyperEdgeSets into a single batched HyperEdgeSet.
Concatenate several single HyperEdgeSets into one single HyperEdgeSet.
Separate a batched HyperEdgeSet into its constituent HyperEdgeSet instances.
Ensure all arrays in a dictionary have the same size on their last axis.
Builds a numpy array representing the number of hyper-edges.
Stack a dictionary of arrays into a single array along the last axis.
Validate that the input is either a dict or None.
Ensure there are no NaN values in port or feature arrays.
Batches a list of GraphShapes into one batched GraphShape.
Returns the maximum graph shape from a list of graph shapes.
Splits a batched GraphShape into individual GraphShape instances.
Returns the sum graph shape from a list of graph shapes.
Converts a NumPy array, JAX array, or tuple of values into a NumPy array (dtype float32), or converts the values in a dictionary accordingly.
Convert NumPy arrays or dictionary of NumPy arrays to JAX arrays.
Convert JAX arrays or mappings of JAX arrays back to NumPy arrays.
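Concatenating single graphs into one larger graph, as opposed to collating them into a padded batch, requires shifting addresses so that the graphs stay disjoint; this is presumably where the "Adds an offset on all addresses" operation listed for Graph and HyperEdgeSet comes in. The sketch below works on plain dicts and assumes that the addresses of each graph are the integers 0 to n-1; `concat_graphs` and its input layout are hypothetical and not the package's function.

```python
import numpy as np


def concat_graphs(graphs):
    """Hypothetical sketch: disjoint union of single graphs given as plain dicts.

    Each graph: {"n_addresses": int,
                 "edges": {cls: {"ports": {name: int array}, "features": (n, f) array}}}
    """
    offset = 0
    ports, features = {}, {}
    for g in graphs:
        for cls, es in g["edges"].items():
            for name, addr in es["ports"].items():
                # Shift addresses so they do not collide with earlier graphs.
                ports.setdefault(cls, {}).setdefault(name, []).append(addr + offset)
            features.setdefault(cls, []).append(es["features"])
        offset += g["n_addresses"]
    return {
        "n_addresses": offset,
        "edges": {
            cls: {
                "ports": {name: np.concatenate(parts) for name, parts in ports.get(cls, {}).items()},
                "features": np.concatenate(features[cls], axis=0),
            }
            for cls in features
        },
    }


g1 = {"n_addresses": 2, "edges": {"Spring": {
    "ports": {"node1_id": np.array([0]), "node2_id": np.array([1])},
    "features": np.array([[0.02]])}}}
g2 = {"n_addresses": 3, "edges": {"Spring": {
    "ports": {"node1_id": np.array([0, 1]), "node2_id": np.array([1, 2])},
    "features": np.array([[0.03], [0.04]])}}}
big = concat_graphs([g1, g2])
print(big["edges"]["Spring"]["ports"]["node1_id"])  # [0 2 3]
```

Because the offset grows by each graph's address count, the springs of g2 now point at addresses 2 to 4, leaving g1's addresses 0 and 1 untouched.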