JaxGraphShape.tree_unflatten
- classmethod JaxGraphShape.tree_unflatten(aux_data, children)
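Reconstruct a JaxGraphShape from its flattened representation (a sequence of keys plus the matching array values). JAX calls this when rebuilding the object as a pytree, so it is required for JAX compatibility.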
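The contract described here, where `aux_data` holds keys and `children` holds the arrays in the same order, matches JAX's `tree_flatten` / `tree_unflatten` pytree protocol. The sketch below illustrates that round trip with a hypothetical stand-in class, `GraphShapeSketch`. It is not the real JaxGraphShape. It assumes the shape stores arrays in a dict keyed by name and that keys are flattened in sorted order; both are illustrative assumptions. Only `jax.tree_util.register_pytree_node_class`, `tree_flatten`, `tree_unflatten`, `tree_map`, and `jax.jit` are real JAX APIs.

```python
from typing import Any, Dict, Sequence

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class GraphShapeSketch:
    """Hypothetical stand-in for JaxGraphShape: named arrays in a dict."""

    def __init__(self, fields: Dict[str, Any]):
        # No validation here: under jit, the values may be tracers, not arrays.
        self.fields = dict(fields)

    def tree_flatten(self):
        # Assumption: keys are sorted so that flattening is deterministic.
        keys = tuple(sorted(self.fields))
        children = tuple(self.fields[k] for k in keys)
        return children, keys  # (children, aux_data); aux_data must be hashable

    @classmethod
    def tree_unflatten(cls, aux_data: Sequence[str], children: Sequence[Any]):
        # aux_data[i] names children[i], mirroring the documented contract.
        return cls(dict(zip(aux_data, children)))


shape = GraphShapeSketch({"nodes": jnp.ones(3), "edges": jnp.zeros(2)})

# Explicit round trip through JAX's pytree machinery.
leaves, treedef = jax.tree_util.tree_flatten(shape)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# tree_map flattens, maps over the leaves, then calls tree_unflatten.
doubled = jax.tree_util.tree_map(lambda x: x * 2, shape)


@jax.jit
def total(s: GraphShapeSketch):
    # Inside jit, s was rebuilt by tree_unflatten with traced children.
    return sum(jnp.sum(v) for v in s.fields.values())
```

Returning the keys as a tuple keeps `aux_data` hashable. JAX needs this because the tree structure is compared and cached, for example when `jax.jit` decides whether to retrace.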
- Parameters:
aux_data – Sequence of keys matching the order of the children.
children – Sequence of array values.
- Returns:
A reconstructed JaxGraphShape instance.
- Return type: