JaxGraph.tree_unflatten
- classmethod JaxGraph.tree_unflatten(aux_data, children)
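Reconstructs a JaxGraph from its flattened pytree form (a sequence of keys plus the matching array values). JAX needs this method to treat a JaxGraph as a pytree, for example in jax.tree_util utilities and in transformations such as jax.jit.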
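JAX calls tree_unflatten whenever it rebuilds a pytree node from its leaves, for example in jax.tree_util.tree_map or when tracing a jax.jit-compiled function. The sketch below illustrates the contract described here using a hypothetical ToyGraph class, not JaxGraph itself. It assumes a graph stored as a dict of named arrays, with tree_flatten returning the arrays as children and their keys, in the same order, as aux_data. Registration through jax.tree_util.register_pytree_node_class is also an assumption about how JaxGraph is wired into JAX.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGraph:
    """Hypothetical stand-in for JaxGraph: a mapping from string keys to arrays."""

    def __init__(self, data):
        self.data = dict(data)

    def tree_flatten(self):
        # Fix a key order so that children and aux_data line up position by position.
        keys = tuple(sorted(self.data))
        children = tuple(self.data[k] for k in keys)
        return children, keys  # aux_data must be hashable; a tuple of str is

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data: keys, in the same order as children (the array values).
        return cls(zip(aux_data, children))


g = ToyGraph({"nodes": jnp.ones(3), "edges": jnp.arange(4.0)})

# Round trip: flatten to leaves plus a treedef, then rebuild via tree_unflatten.
leaves, treedef = jax.tree_util.tree_flatten(g)
g2 = jax.tree_util.tree_unflatten(treedef, leaves)

# tree_map flattens g, applies the function to each leaf, then unflattens.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, g)


@jax.jit
def scale(graph):
    # Inside jit, JAX rebuilds `graph` from traced leaves via tree_unflatten.
    return jax.tree_util.tree_map(lambda x: 3 * x, graph)


tripled = scale(g)
```

Sorting the keys in tree_flatten is one way to guarantee a deterministic order; any scheme works as long as tree_unflatten pairs each key with the child at the same position.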
- Parameters:
aux_data – Sequence of keys matching the order of the children.
children – Sequence of array values.
- Returns:
A reconstructed JaxGraph instance.
- Return type: