JaxGraphShape.from_numpy_shape
- classmethod JaxGraphShape.from_numpy_shape(shape, device=None, dtype='float32')
Convert a NumPy-based GraphShape into a jax.numpy-based representation for GNN processing.
This method converts all array-like attributes of a GraphShape object into their JAX equivalents, enabling efficient use with JAX transformations and accelerators.
- Parameters:
shape (GraphShape) – A shape object containing NumPy arrays to convert.
device (Device | None) – Optional JAX device (e.g., CPU, GPU) to place the converted arrays on. If None, JAX uses the default device.
dtype (str) – Desired floating-point precision for converted arrays (e.g., “float32”, “float64”).
- Returns:
A JAX-compatible version of the shape, ready for use in GNN pipelines.
- Return type:
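The method's internals are not shown on this page. Below is a minimal, hypothetical sketch of the field-by-field conversion that the description implies: it walks the fields of a dataclass and passes each NumPy array through a converter function. `ToyGraphShape`, `convert_array_fields`, and `make_caster` are illustrative names that are not part of this library. Leaving integer arrays (such as edge indices) at their original dtype is also an assumption. NumPy stands in for `jax.numpy` so that the sketch runs without JAX installed.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Callable

import numpy as np


@dataclass(frozen=True)
class ToyGraphShape:
    """Hypothetical stand-in for GraphShape: array fields plus plain metadata."""
    node_features: np.ndarray
    edge_index: np.ndarray
    name: str = "graph"


def convert_array_fields(obj: Any, to_array: Callable[[np.ndarray], Any]) -> Any:
    """Return a copy of a dataclass with every ndarray field passed through to_array.

    Non-array fields (e.g. the name string) are copied unchanged.
    """
    updates = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        if isinstance(value, np.ndarray):
            updates[f.name] = to_array(value)
    return replace(obj, **updates)


def make_caster(dtype: str = "float32") -> Callable[[np.ndarray], np.ndarray]:
    """Build a converter that casts floating-point arrays to `dtype`.

    Assumption: integer arrays (e.g. edge indices) keep their dtype.
    With JAX, the float branch could instead call
    jax.device_put(jnp.asarray(a, dtype=dtype), device).
    """
    def cast(a: np.ndarray) -> np.ndarray:
        if np.issubdtype(a.dtype, np.floating):
            return np.asarray(a, dtype=dtype)
        return np.asarray(a)
    return cast


shape = ToyGraphShape(
    node_features=np.ones((3, 2), dtype=np.float64),
    edge_index=np.array([[0, 1], [1, 2]], dtype=np.int64),
)
converted = convert_array_fields(shape, make_caster("float32"))
print(converted.node_features.dtype, converted.edge_index.dtype)  # float32 int64
```

In a JAX version, `jax.device_put` with `device=None` places arrays on the default device, which matches the `device` parameter described above. JAX keeps 64-bit floats only when the `jax_enable_x64` flag is enabled (for example with `jax.config.update("jax_enable_x64", True)`). Without that flag, requesting `"float64"` produces `float32` arrays.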