JaxHyperEdgeSet.tree_unflatten

classmethod JaxHyperEdgeSet.tree_unflatten(aux_data, children)

Unflattens a PyTree, required for JAX compatibility.

This method reconstructs an instance of the class from a flattened PyTree structure by pairing each key in aux_data with the corresponding value in children.

Parameters:
  • aux_data (Sequence[str]) – Sequence of keys (the auxiliary data) originally returned by tree_flatten.

  • children (Sequence[Any]) – Sequence of values originally returned by tree_flatten.

Returns:

Reconstructed instance of the class (JaxHyperEdgeSet).

Raises:

KeyError – If an expected key is missing from the dictionary formed by zipping aux_data with children.

Return type:

JaxHyperEdgeSet
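The round trip described above (keys carried in aux_data, values carried in children, zipped back into a dictionary) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: ToyHyperEdgeSet and its fields edges and weights are hypothetical stand-ins for JaxHyperEdgeSet, and tree_flatten is assumed to return the field names as aux_data.

```python
from dataclasses import dataclass, fields
from typing import Any, Sequence


# Hypothetical stand-in for JaxHyperEdgeSet; the field names are
# illustrative assumptions, not the library's actual attributes.
@dataclass
class ToyHyperEdgeSet:
    edges: Any
    weights: Any

    def tree_flatten(self):
        # JAX convention: return (children, aux_data).
        # Keys are static metadata; values are the PyTree leaves.
        keys = tuple(f.name for f in fields(self))
        children = tuple(getattr(self, k) for k in keys)
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data: Sequence[str], children: Sequence[Any]):
        # Zip keys with values into a dictionary.
        mapping = dict(zip(aux_data, children))
        # Indexing by field name raises KeyError if an expected key is absent.
        return cls(**{f.name: mapping[f.name] for f in fields(cls)})


# Round trip: flatten, then rebuild an equal instance.
original = ToyHyperEdgeSet(edges=[[0, 1], [1, 2]], weights=[0.5, 2.0])
children, aux = original.tree_flatten()
restored = ToyHyperEdgeSet.tree_unflatten(aux, children)
assert restored == original
```

Because the methods follow JAX's flatten/unflatten protocol, decorating such a class with jax.tree_util.register_pytree_node_class would let JAX transformations traverse it; that registration is left out here so the sketch runs without JAX installed.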