JaxGraphShape.to_numpy_shape

JaxGraphShape.to_numpy_shape()

Convert a jax.numpy graph shape used for GNN processing into a classical NumPy shape.

This method transforms the internal JAX arrays of the shape back into standard NumPy arrays, enabling compatibility with non-JAX components.
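The conversion can be pictured with the minimal sketch below. It is not the library's implementation. The field layout of `GraphShape` and `JaxGraphShape` is not documented here, so the `edges` dict (size per edge class) and the `addresses` scalar are hypothetical stand-ins. The sketch only illustrates the general pattern of applying `np.asarray` to every array field.

```python
from dataclasses import dataclass
from typing import Any

import numpy as np


# Hypothetical stand-ins: the real GraphShape / JaxGraphShape fields are not
# documented in this section, so an `edges` dict and an `addresses` scalar are assumed.
@dataclass
class GraphShape:
    edges: dict[str, np.ndarray]   # assumed: size per edge class, as NumPy arrays
    addresses: np.ndarray          # assumed: number of addresses, as a NumPy array


@dataclass
class JaxGraphShape:
    edges: dict[str, Any]          # assumed: jax.Array values in practice
    addresses: Any                 # assumed: a jax.Array in practice

    def to_numpy_shape(self) -> GraphShape:
        # np.asarray turns each array-like (a jax.Array in practice) into a
        # host-side numpy.ndarray; an input that is already an ndarray is
        # returned unchanged.
        return GraphShape(
            edges={name: np.asarray(size) for name, size in self.edges.items()},
            addresses=np.asarray(self.addresses),
        )
```

Passing `jax.numpy` arrays (for example `jnp.array(4)`) as the field values works the same way, because `np.asarray` accepts any `jax.Array` and copies it to the host if needed.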

Returns:

A classical GraphShape object with NumPy arrays.

Return type:

GraphShape