JaxHyperEdgeSet.tree_unflatten
- classmethod JaxHyperEdgeSet.tree_unflatten(aux_data, children)
Unflattens a PyTree, as required for JAX compatibility.
This method reconstructs an instance of the class from the flattened PyTree structure produced by tree_flatten.
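A minimal sketch of the flatten/unflatten pattern described here, assuming the class stores its data under named fields. `HyperEdgeSetSketch` and its `hyperedges` and `weights` fields are hypothetical stand-ins, not the real JaxHyperEdgeSet attributes. The sketch zips the keys with the values into a dictionary and then looks up each expected field, so a missing key raises KeyError.

```python
from typing import Any, List, Sequence, Tuple


class HyperEdgeSetSketch:
    """Hypothetical stand-in for JaxHyperEdgeSet; field names are assumptions."""

    # Hypothetical field names; the real class may store different attributes.
    _FIELDS = ("hyperedges", "weights")

    def __init__(self, hyperedges: Any, weights: Any) -> None:
        self.hyperedges = hyperedges
        self.weights = weights

    def tree_flatten(self) -> Tuple[List[Any], Tuple[str, ...]]:
        # children: the leaf values JAX may trace or transform.
        # aux_data: static, hashable metadata (here, the tuple of keys).
        keys = self._FIELDS
        children = [getattr(self, k) for k in keys]
        return children, keys

    @classmethod
    def tree_unflatten(
        cls, aux_data: Sequence[str], children: Sequence[Any]
    ) -> "HyperEdgeSetSketch":
        # Pair each key with its value, then look up the expected fields.
        # A key absent from aux_data raises KeyError.
        fields = dict(zip(aux_data, children))
        return cls(hyperedges=fields["hyperedges"], weights=fields["weights"])


# Round trip: flatten an instance, then rebuild an equivalent one.
original = HyperEdgeSetSketch(hyperedges=[[0, 1, 2], [2, 3]], weights=[1.0, 0.5])
children, aux_data = original.tree_flatten()
rebuilt = HyperEdgeSetSketch.tree_unflatten(aux_data, children)
```

With JAX installed, decorating such a class with `jax.tree_util.register_pytree_node_class` registers it as a pytree node; JAX then calls `tree_flatten` and `tree_unflatten` itself, for example inside `jax.tree_util.tree_map`.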
- Parameters:
aux_data (Sequence[str]) – Tuple of keys originally returned by tree_flatten.
children (Sequence[Any]) – Sequence of values originally returned by tree_flatten.
- Returns:
Reconstructed instance of the class (JaxHyperEdgeSet).
- Raises:
KeyError – If an expected key is missing from the dictionary built by zipping aux_data with children.
- Return type:
JaxHyperEdgeSet