JaxGraph.tree_flatten

JaxGraph.tree_flatten()

Flattens the JaxGraph for JAX PyTree compatibility.

Returns:

A pair of the flat children and the auxiliary data (the order of the keys).
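JAX's PyTree protocol expects `tree_flatten` to return a `(children, aux_data)` pair. It also expects a matching `tree_unflatten` classmethod that rebuilds the object from those two parts. The sketch below shows that pattern on a hypothetical `ToyGraph` class that stores named arrays in a dict. It is not the actual JaxGraph implementation. The dict-based storage and the use of insertion order for the keys are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyGraph:
    """Hypothetical stand-in for JaxGraph: a mapping from names to arrays."""

    def __init__(self, data):
        self.data = dict(data)

    def tree_flatten(self):
        # Assumption: record the keys in insertion order as static aux data.
        keys = tuple(self.data)
        # children: the array leaves that JAX transforms (jit, grad, tree_map, ...).
        children = tuple(self.data[k] for k in keys)
        # aux_data: hashable metadata needed to rebuild the object (the key order).
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Pair each key with its child in the same order used by tree_flatten.
        return cls(dict(zip(aux_data, children)))


# Usage: once registered, the object works with generic tree utilities.
g = ToyGraph({"nodes": jnp.ones(3), "edges": jnp.zeros(2)})
doubled = jax.tree_util.tree_map(lambda x: x * 2, g)
leaves, treedef = jax.tree_util.tree_flatten(g)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Storing the key order as auxiliary data keeps it static: JAX compares it when matching tree structures, but never traces it as an array.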