JaxGraphShape.tree_flatten

JaxGraphShape.tree_flatten()

Flatten the JaxGraphShape for JAX PyTree compatibility.

Returns:

A pair (children, aux_data): the flat children, and auxiliary data that records the order of the keys.
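The method follows JAX's custom PyTree protocol. The sketch below shows how a keyed container can meet that contract. It uses a hypothetical class, ShapeMap, as a stand-in for JaxGraphShape, whose real fields are not shown here. It assumes the container maps keys to arrays. tree_flatten returns the array values as children and the key tuple as static auxiliary data, and a matching tree_unflatten classmethod rebuilds the object from them. Registration uses jax.tree_util.register_pytree_node_class.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ShapeMap:
    """Hypothetical stand-in for JaxGraphShape: maps keys to arrays (assumed layout)."""

    def __init__(self, entries):
        self.entries = dict(entries)

    def tree_flatten(self):
        # Fix the key order (insertion order here) so the round trip is deterministic.
        keys = tuple(self.entries.keys())
        # Children are the array leaves JAX will trace and transform.
        children = tuple(self.entries[k] for k in keys)
        # Auxiliary data is the key order: static and hashable.
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Pair each child with its key using the stored order.
        return cls(zip(aux_data, children))


# Usage: flatten directly, then let JAX map over the leaves.
shape = ShapeMap({"nodes": jnp.ones((3, 2)), "edges": jnp.zeros((4,))})
children, keys = shape.tree_flatten()
doubled = jax.tree_util.tree_map(lambda x: 2 * x, shape)
leaves, treedef = jax.tree_util.tree_flatten(shape)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Keeping the keys in the auxiliary data, rather than among the children, keeps them static. JAX transformations then see only the arrays as leaves, and the original structure can be restored exactly after a transformation.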