JaxGraph.tree_unflatten

classmethod JaxGraph.tree_unflatten(aux_data, children)

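Reconstructs a JaxGraph from its flattened representation. This is part of the JAX pytree protocol, which lets JaxGraph instances be passed through JAX transformations such as jax.jit and utilities such as jax.tree_util.tree_map.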

Parameters:
  • aux_data – Sequence of keys matching the order of the children.

  • children – Sequence of array values.

Returns:

A reconstructed JaxGraph instance.

Return type:

JaxGraph
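The following is a minimal sketch of how a keys-plus-arrays class can implement this protocol with jax.tree_util.register_pytree_node_class. MiniGraph is a hypothetical stand-in: the real JaxGraph internals and its flattening method are not documented here, so this only illustrates the documented contract that aux_data[i] is the key for children[i]. The keys are returned as a tuple because JAX expects aux_data to be hashable. JAX may also call tree_unflatten with placeholder objects instead of arrays, so the constructor only stores its input and does not validate it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MiniGraph:
    """Hypothetical stand-in for JaxGraph: a mapping from keys to arrays."""

    def __init__(self, data):
        # Store as-is; no type checks, since JAX may pass placeholder leaves.
        self.data = dict(data)

    def tree_flatten(self):
        # Fix a key order so that keys and values line up one-to-one.
        keys = tuple(sorted(self.data))
        children = tuple(self.data[k] for k in keys)
        # children are the leaves JAX transforms; keys are static aux_data.
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data[i] is the key for children[i].
        return cls(zip(aux_data, children))


g = MiniGraph({"nodes": jnp.ones(3), "edges": jnp.zeros((2, 2))})

# tree_map flattens g, applies the function to every leaf, then calls
# MiniGraph.tree_unflatten to rebuild a MiniGraph from the results.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, g)


@jax.jit
def total(graph):
    # Inside jit, graph is rebuilt from traced leaves via tree_unflatten.
    return sum(jnp.sum(v) for v in graph.data.values())


print(total(doubled))  # 6.0
```

Fixing the key order in the flattening step is the design choice that matters: whatever order the keys are emitted in, the children must follow the same order, or tree_unflatten will pair keys with the wrong arrays.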