JaxGraph.tree_flatten

JaxGraph.tree_flatten()

Flattens the JaxGraph for JAX PyTree compatibility.

Returns: A pair of the flat children and the auxiliary data (the key order).
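The internals of JaxGraph are not shown here, so the following is only a minimal sketch of the flatten/unflatten protocol this method takes part in. It assumes a hypothetical `MiniGraph` class that stores named arrays in a dict. `tree_flatten` returns the arrays as children and the key order as auxiliary data. JAX's `register_pytree_node_class` also requires a matching `tree_unflatten` classmethod; the one below is illustrative and is not taken from JaxGraph.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MiniGraph:
    """Hypothetical stand-in for JaxGraph: a mapping from names to arrays."""

    def __init__(self, data):
        self.data = dict(data)

    def tree_flatten(self):
        # Record the key order so that unflattening rebuilds the same mapping.
        keys = tuple(self.data.keys())
        # Children are the array leaves that JAX transforms and traces.
        children = tuple(self.data[k] for k in keys)
        # Auxiliary data is static metadata. Here it is just the key order.
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Pair the stored key order with the (possibly transformed) children.
        return cls(zip(aux_data, children))


g = MiniGraph({"nodes": jnp.ones(3), "edges": jnp.zeros((2, 2))})
leaves, treedef = jax.tree_util.tree_flatten(g)
# tree_map flattens g, applies the function to each leaf, then unflattens.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, g)
```

Keeping the key order in the auxiliary data rather than in the children keeps it static. JAX then treats two graphs with the same keys as the same tree structure, which lets transformations such as `jax.jit` reuse a compiled function for them.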