JaxGraphShape.tree_unflatten

classmethod JaxGraphShape.tree_unflatten(aux_data, children)

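Reconstruct a JaxGraphShape from its flattened pytree data. This classmethod is part of the JAX pytree protocol. JAX calls it to rebuild the object after flattening, for example during jax.tree_util.tree_map or when the object crosses a jax.jit boundary.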

Parameters:
  • aux_data – Sequence of keys matching the order of the children.

  • children – Sequence of array values, one for each key in aux_data.

Returns:

A reconstructed JaxGraphShape instance.

Return type:

JaxGraphShape
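The flatten/unflatten contract can be illustrated with a small sketch. It uses a hypothetical ToyGraphShape class (not part of this library) and assumes, without confirmation from this page, that the object stores a mapping of keys to arrays and that the matching tree_flatten returns (children, aux_data) with the keys as auxiliary data. If JAX is installed, the sketch registers the class with jax.tree_util.register_pytree_node_class; otherwise it runs as plain Python.

```python
from typing import Any, Dict, Sequence, Tuple


class ToyGraphShape:
    """Hypothetical stand-in for JaxGraphShape: a mapping of keys to arrays."""

    def __init__(self, data: Dict[str, Any]):
        self.data = dict(data)

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        # Sort keys so equal key sets always give identical aux_data;
        # a tuple is used because JAX expects aux_data to be hashable.
        keys = tuple(sorted(self.data))
        children = tuple(self.data[k] for k in keys)
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data: Sequence[str], children: Sequence[Any]) -> "ToyGraphShape":
        # aux_data holds the keys in the same order as children,
        # so zipping them restores the original mapping.
        return cls(dict(zip(aux_data, children)))


# Register as a pytree only when JAX is available (assumption for this sketch).
try:
    import jax

    jax.tree_util.register_pytree_node_class(ToyGraphShape)
except ImportError:
    pass


if __name__ == "__main__":
    shape = ToyGraphShape({"nodes": [1, 2], "edges": [3]})
    children, aux = shape.tree_flatten()
    rebuilt = ToyGraphShape.tree_unflatten(aux, children)
    print(aux, children, rebuilt.data == shape.data)
```

Sorting the keys in tree_flatten is a design choice of this sketch: JAX compares aux_data when checking whether two pytrees share a structure, so a deterministic key order avoids spurious mismatches.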