JaxHyperEdgeSet.from_numpy_hyper_edge_set

classmethod JaxHyperEdgeSet.from_numpy_hyper_edge_set(hyper_edge_set, device=None, dtype='float32')

Convert a NumPy-based hyper-edge set into its jax.numpy counterpart for GNN processing.

This method converts every array-like attribute of a HyperEdgeSet object into a JAX array. The result can then be used inside JAX transformations (such as jax.jit, jax.grad, or jax.vmap) and placed on accelerators.

Parameters:
  • hyper_edge_set (HyperEdgeSet) – A hyper-edge set object containing NumPy arrays to convert.

  • device (Device | None) – Optional JAX device (e.g., CPU, GPU) to place the converted arrays on. If None, JAX uses the default device.

  • dtype (str) – Desired floating-point precision for the converted arrays (e.g., "float32", "float64"). Note that JAX stores float64 values as float32 unless 64-bit mode (jax_enable_x64) is enabled.

Returns:

A JAX-compatible version of the hyper-edge set, ready for use in GNN pipelines.

Return type:

JaxHyperEdgeSet
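The conversion described above can be pictured with a minimal sketch. It does not use the library's HyperEdgeSet class: `ToyHyperEdgeSet` and `to_jax` are hypothetical names, and casting only floating-point arrays to `dtype` while leaving integer arrays untouched is an assumption, not documented behavior of from_numpy_hyper_edge_set. The sketch only relies on standard JAX calls (`jax.device_put`) and the Python `dataclasses` module.

```python
from __future__ import annotations

from dataclasses import dataclass, fields, replace
from typing import Optional

import jax
import numpy as np


@dataclass
class ToyHyperEdgeSet:
    """Hypothetical stand-in for HyperEdgeSet: NumPy arrays plus metadata."""
    name: str
    port_addresses: np.ndarray  # integer node addresses (hypothetical field)
    features: np.ndarray        # floating-point edge features (hypothetical field)


def to_jax(edge_set, device: Optional[jax.Device] = None, dtype: str = "float32"):
    """Return a copy of `edge_set` whose np.ndarray fields are JAX arrays.

    Assumption: only floating-point arrays are cast to `dtype`; other
    arrays (e.g. integer addresses) keep their original dtype.
    """
    updates = {}
    for f in fields(edge_set):
        value = getattr(edge_set, f.name)
        if isinstance(value, np.ndarray):
            if np.issubdtype(value.dtype, np.floating):
                value = value.astype(dtype)
            # device_put copies the array to `device`, or to JAX's default device if None.
            updates[f.name] = jax.device_put(value, device)
    # Non-array fields (such as `name`) are carried over unchanged.
    return replace(edge_set, **updates)


# Usage: build a small NumPy-based edge set and convert it.
edges = ToyHyperEdgeSet(
    name="lines",
    port_addresses=np.array([[0, 1], [1, 2]], dtype=np.int64),
    features=np.array([[1.0], [2.5]], dtype=np.float64),
)
jax_edges = to_jax(edges)  # default device, float32 features
```

Returning a new object with `dataclasses.replace` rather than mutating the input keeps the original NumPy version intact, which mirrors the classmethod's "build a new JaxHyperEdgeSet" shape.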