JaxGraphShape.tree_flatten

JaxGraphShape.tree_flatten()

Flatten the JaxGraphShape for JAX PyTree compatibility.

Returns: The flat children and the auxiliary data (the order of the keys).
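The sketch below shows how a `tree_flatten` of this kind typically pairs with a `tree_unflatten` classmethod so that JAX can traverse the object as a PyTree. It does not use the real JaxGraphShape. `ToyGraphShape` is a hypothetical stand-in that assumes the shape is a mapping from string keys to arrays. The arrays are returned as children, and the key order is returned as auxiliary data, which mirrors the return value described above. `jax.tree_util.register_pytree_node_class` is a real JAX decorator. Everything else here is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGraphShape:
    """Hypothetical stand-in for JaxGraphShape: a mapping from keys to arrays."""

    def __init__(self, fields):
        self.fields = dict(fields)

    def tree_flatten(self):
        # Assumption: keys keep insertion order. That order is static metadata,
        # so it goes into the (hashable) auxiliary data, not the children.
        keys = tuple(self.fields.keys())
        # Children are the array leaves that JAX maps over or traces.
        children = tuple(self.fields[k] for k in keys)
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the mapping by pairing the saved key order with the children.
        return cls(dict(zip(aux_data, children)))


# Usage: JAX tree utilities and jit now accept the object directly.
shape = ToyGraphShape({"nodes": jnp.zeros(3), "edges": jnp.ones(2)})
doubled = jax.tree_util.tree_map(lambda x: x * 2, shape)


@jax.jit
def total(s):
    # Sum every array in the structure; s arrives as a ToyGraphShape of tracers.
    return sum(jnp.sum(v) for v in s.fields.values())


result = total(shape)
```

Keeping the key order in the auxiliary data rather than in the children matters because JAX treats auxiliary data as static. It must be hashable, and a change in it triggers recompilation under `jit` instead of being traced.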