Coupler
- class Coupler(*args, **kwargs)
Bases:
Module, ABC
Interface for a coupler.
A coupler takes a graph as input and returns latent coordinates for each address. Graph information should be injected into the latent coordinates in a permutation-equivariant manner: relabelling the addresses of the input graph permutes the output coordinates in the same way.
- Return type:
Any
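To make the permutation-equivariance requirement concrete, here is a minimal NumPy sketch. `toy_coupler` is a hypothetical stand-in (one round of neighbor-sum message passing over a dense adjacency matrix), not the library's `Coupler` API. `is_permutation_equivariant` relabels the addresses with a given permutation and checks that the output coordinates are relabelled in the same way.

```python
import numpy as np


def toy_coupler(adjacency: np.ndarray, features: np.ndarray) -> np.ndarray:
    """Hypothetical coupler: one round of neighbor-sum message passing.

    adjacency: (n, n) 0/1 matrix over addresses; features: (n, f) per-address inputs.
    Returns (n, f) latent coordinates.
    """
    # Summing over neighbors does not depend on how the addresses are numbered.
    return np.tanh(adjacency @ features + features)


def is_permutation_equivariant(coupler, adjacency, features, perm) -> bool:
    """Check coupler(relabelled graph) == relabelled coupler(graph) for `perm`."""
    permuted_adjacency = adjacency[np.ix_(perm, perm)]  # permute rows and columns
    out_of_permuted = coupler(permuted_adjacency, features[perm])
    return np.allclose(out_of_permuted, coupler(adjacency, features)[perm])


def positional(adjacency: np.ndarray, features: np.ndarray) -> np.ndarray:
    """Counterexample: adding the address index makes the output depend on numbering."""
    return features + np.arange(len(features))[:, None]


A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
X = np.arange(6, dtype=float).reshape(3, 2)
perm = np.array([2, 0, 1])
print(is_permutation_equivariant(toy_coupler, A, X, perm))  # True
print(is_permutation_equivariant(positional, A, X, perm))   # False
```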
Implementations
- class RecurrentCoupler(*args, **kwargs)
Bases:
Coupler
Simplified version of the Neural Ordinary Differential Equation solver.
The following recurrent system is used:
\[\forall a \in \mathcal{A}_x, h_a(t+\delta t) = h_a(t) + \delta t \times \phi_\theta(\psi^1_\theta(h;x)_a, \dots, \psi^n_\theta(h;x)_a),\]
with the following initial condition:
\[\forall a \in \mathcal{A}_x, h_a(t=0) = [0, \dots, 0].\]
- Parameters:
phi – Outer MLP \(\phi_\theta\).
message_functions – List of message functions \((\psi^i_\theta)_i\).
n_steps – Number of message passing steps.
- Return type:
Any
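The recurrence above is the explicit (forward) Euler scheme applied to the ODE \(dh_a/dt = \phi_\theta(\psi^1_\theta(h;x)_a, \dots, \psi^n_\theta(h;x)_a)\). A minimal NumPy sketch follows, under assumptions that are not taken from the library: `x` is a per-address feature array, each message function is a plain callable `(h, x) -> array`, \(\phi_\theta\) receives the messages concatenated along the last axis, and \(\delta t = 1 / \text{n\_steps}\) so that the updates cover \(t \in [0, 1]\).

```python
from typing import Callable, Sequence

import numpy as np

# Hypothetical calling convention: (h, x) -> messages, one row per address.
MessageFn = Callable[[np.ndarray, np.ndarray], np.ndarray]


def recurrent_coupler(
    x: np.ndarray,
    message_functions: Sequence[MessageFn],
    phi: Callable[[np.ndarray], np.ndarray],
    latent_dimension: int,
    n_steps: int,
) -> np.ndarray:
    """h(t + dt) = h(t) + dt * phi(psi^1(h;x), ..., psi^n(h;x)), starting from h(0) = 0.

    Assumption: dt = 1 / n_steps, so the recurrence spans t in [0, 1].
    """
    h = np.zeros((x.shape[0], latent_dimension))  # h_a(t=0) = [0, ..., 0]
    dt = 1.0 / n_steps
    for _ in range(n_steps):
        # Every message function sees the current coordinates of all addresses.
        messages = np.concatenate([psi(h, x) for psi in message_functions], axis=-1)
        h = h + dt * phi(messages)
    return h


# Toy check: with psi(h; x) = x - h and phi = identity, each step shrinks h - x by (1 - dt),
# so after n steps h = x * (1 - (1 - dt) ** n); here that is x * (1 - 0.75 ** 4).
x = np.array([[1.0], [2.0]])
h = recurrent_coupler(x, [lambda h, x: x - h], phi=lambda m: m, latent_dimension=1, n_steps=4)
print(h.ravel())
```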
- class NODECoupler(*args, **kwargs)
Bases:
Coupler
Output coordinates are computed by solving a Neural Ordinary Differential Equation.
The following ordinary differential equation is integrated between 0 and 1:
\[\forall a \in \mathcal{A}_x, \frac{dh_a}{dt} = \phi_\theta(\psi^1_\theta(h;x)_a, \dots, \psi^n_\theta(h;x)_a),\]
with the following initial condition:
\[\forall a \in \mathcal{A}_x, h_a(t=0) = [0, \dots, 0].\]
The implementation relies on Patrick Kidger’s Diffrax.
- Parameters:
phi – Outer MLP \(\phi_\theta\).
message_functions – List of message functions \((\psi^i_\theta)_i\).
latent_dimension – Dimension of address latent coordinates.
dt – Initial step size.
stepsize_controller – Controller for adaptive step size methods.
adjoint – Method used to backpropagate gradients through the ODE solve.
solver – Numerical solver for the ODE.
max_steps – Maximum number of solver steps allowed when solving the ODE.
- Return type:
Any
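The class integrates this ODE with Diffrax, using the solver, step-size controller and adjoint passed as parameters. To show only the structure of the vector field, the sketch below swaps Diffrax for a hand-written, fixed-step classical Runge-Kutta (RK4) integrator in NumPy. The calling convention (message functions as callables `(h, x) -> array`, \(\phi_\theta\) applied to their concatenation) is an assumption, not the library's API.

```python
from typing import Callable, Sequence

import numpy as np


def node_coupler_rk4(
    x: np.ndarray,
    message_functions: Sequence[Callable[[np.ndarray, np.ndarray], np.ndarray]],
    phi: Callable[[np.ndarray], np.ndarray],
    latent_dimension: int,
    n_steps: int = 20,
) -> np.ndarray:
    """Integrate dh/dt = phi(psi^1(h;x), ..., psi^n(h;x)) from t=0 to t=1 with fixed-step RK4."""

    def vector_field(h: np.ndarray) -> np.ndarray:
        # The right-hand side has no explicit t dependence: the system is autonomous.
        messages = np.concatenate([psi(h, x) for psi in message_functions], axis=-1)
        return phi(messages)

    h = np.zeros((x.shape[0], latent_dimension))  # h_a(t=0) = [0, ..., 0]
    dt = 1.0 / n_steps  # fixed step; an adaptive step-size controller would adjust it
    for _ in range(n_steps):
        k1 = vector_field(h)
        k2 = vector_field(h + 0.5 * dt * k1)
        k3 = vector_field(h + 0.5 * dt * k2)
        k4 = vector_field(h + dt * k3)
        h = h + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return h


# Toy check: dh/dt = x - h with h(0) = 0 has the exact solution h(1) = x * (1 - e^{-1}).
x = np.array([[1.0], [2.0]])
h1 = node_coupler_rk4(x, [lambda h, x: x - h], phi=lambda m: m, latent_dimension=1)
print(np.abs(h1 - x * (1.0 - np.exp(-1.0))).max())  # small: RK4 error is O(dt^4)
```

Replacing the RK4 update with `h = h + dt * vector_field(h)` gives the explicit Euler recurrence used by `RecurrentCoupler` (with \(\delta t = 1/\text{n\_steps}\) under the same assumption).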
Message Passing Functions
- class MessagePassingFunction(*args, **kwargs)
Bases:
Module, ABC
Interface for a message function \(\psi_\theta\) in a GNN message passing scheme.
- Return type:
Any
- class IdentityMessagePassingFunction(*args, **kwargs)
Bases:
MessagePassingFunction
Identity local message function module for GNN message passing.
This module returns the address coordinates unchanged as the local message, i.e. it implements the identity mapping:
\[h^\rightarrow_a = h_a\]
- Return type:
Any
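A minimal sketch of how the interface and its identity implementation fit together, in plain Python and NumPy. `MessageFunctionSketch` and `IdentityMessageSketch` are hypothetical names, and the call signature `(h, x) -> messages` with one row per address is an assumption rather than the library's actual API.

```python
from abc import ABC, abstractmethod

import numpy as np


class MessageFunctionSketch(ABC):
    """Hypothetical interface: maps coordinates h of shape (n_addresses, d) and graph x
    to messages of shape (n_addresses, m), one row per address."""

    @abstractmethod
    def __call__(self, h: np.ndarray, x) -> np.ndarray:
        ...


class IdentityMessageSketch(MessageFunctionSketch):
    """psi(h; x)_a = h_a: ignores the graph and returns the coordinates unchanged."""

    def __call__(self, h: np.ndarray, x) -> np.ndarray:
        return h


h = np.arange(6.0).reshape(3, 2)
message = IdentityMessageSketch()(h, x=None)  # the very same array, no copy
```

If, as assumed here, a coupler feeds \(\phi_\theta\) the concatenation of all messages, including the identity among them lets \(\phi_\theta\) see each address's own current coordinates next to its neighborhood messages.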
- class LocalSumMessagePassingFunction(*args, **kwargs)
Bases:
MessagePassingFunction
Local sum-based message function module for GNN message passing.
This module aggregates messages from each address’s local neighborhood: it applies a class- and port-specific MLP \(\xi^{c,o}_\theta\) to hyper-edge features and neighbor coordinates, sums the results over all hyper-edge ports connected to the address, and applies a final activation \(\sigma\).
For each address \(a\), the output is defined as:
\[\psi_\theta(h;x)_a = \sigma \left( \sum_{(c,e,o)\in \mathcal{N}_x(a)} \xi^{c,o}_\theta(h_e, x_e)\right),\]
where \(\xi^{c,o}_\theta\) is a class-specific and port-specific MLP, \(\sigma\) is an element-wise activation function, \(h_e := (h_{o(e)})_{o \in \mathcal{O}^c}\) is the concatenation of the port coordinates of hyper-edge \(e\), and \(x_e\) denotes the features of hyper-edge \(e\).
- Parameters:
in_graph_structure – Input graph structure.
in_array_size – Size of the input coordinate arrays.
hidden_sizes – Hidden sizes of the MLPs \(\xi^{c,o}_\theta\).
activation – Activation function for the MLPs \(\xi^{c,o}_\theta\).
out_size – Output size of the MLPs \(\xi^{c,o}_\theta\).
use_bias – Whether to use bias in the MLPs \(\xi^{c,o}_\theta\).
kernel_init – Kernel initializer for the MLPs \(\xi^{c,o}_\theta\).
bias_init – Bias initializer for the MLPs \(\xi^{c,o}_\theta\).
final_activation – Final activation function for the MLPs \(\xi^{c,o}_\theta\).
outer_activation – Activation function \(\sigma\) applied over the output.
encoded_feature_size – None if the input data has not been encoded, otherwise the size of the encoded features.
port_scatter_blacklist – Dictionary mapping hyper-edge set keys to lists of port keys to be excluded from the sum.
seed – Seed of the RNG streams used for weight initialization.
- Return type:
Any
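The aggregation formula can be sketched in NumPy as follows. Everything here is a hypothetical simplification, not the library's API: the graph is a dict mapping each hyper-edge class \(c\) to the address indices at each port \(o\) and to the hyper-edge features \(x_e\); each \(\xi^{c,o}_\theta\) is a one-hidden-layer tanh MLP; and `np.add.at` performs the scatter-sum over \((c,e,o) \in \mathcal{N}_x(a)\). The `port_scatter_blacklist` argument mirrors the documented parameter of the same name.

```python
import numpy as np


def init_mlp(rng, in_size, hidden_size, out_size):
    """Hypothetical one-hidden-layer MLP parameters, standing in for xi^{c,o}."""
    return {
        "w1": rng.normal(scale=in_size ** -0.5, size=(in_size, hidden_size)),
        "b1": np.zeros(hidden_size),
        "w2": rng.normal(scale=hidden_size ** -0.5, size=(hidden_size, out_size)),
        "b2": np.zeros(out_size),
    }


def apply_mlp(p, z):
    return np.tanh(z @ p["w1"] + p["b1"]) @ p["w2"] + p["b2"]


def local_sum_message(h, graph, params, out_size,
                      port_scatter_blacklist=None, outer_activation=np.tanh):
    """psi(h;x)_a = sigma( sum over (c, e, o) in N_x(a) of xi^{c,o}(h_e, x_e) ).

    Assumed layout: graph[c] = {"ports": {o: (n_edges,) address indices},
                                "features": (n_edges, f_c) hyper-edge features x_e}.
    """
    blacklist = port_scatter_blacklist or {}
    acc = np.zeros((h.shape[0], out_size))
    for c, edges in graph.items():
        ports = edges["ports"]
        # h_e: concatenation of the coordinates at every port of e (fixed port order).
        h_e = np.concatenate([h[ports[o]] for o in sorted(ports)], axis=-1)
        z = np.concatenate([h_e, edges["features"]], axis=-1)
        for o, addresses in ports.items():
            if o in blacklist.get(c, ()):
                continue  # this port is excluded from the sum
            # Scatter-add xi^{c,o}(h_e, x_e) onto address o(e); repeated indices accumulate.
            np.add.at(acc, addresses, apply_mlp(params[c][o], z))
    return outer_activation(acc)


# Toy graph: 3 addresses joined by two "line" hyper-edges with ports "from" and "to".
rng = np.random.default_rng(0)
d, f, out = 2, 1, 3
graph = {"line": {"ports": {"from": np.array([0, 1]), "to": np.array([1, 2])},
                  "features": np.array([[0.5], [-1.0]])}}
params = {"line": {o: init_mlp(rng, 2 * d + f, 8, out) for o in ("from", "to")}}
h = rng.normal(size=(3, d))
m = local_sum_message(h, graph, params, out)
print(m.shape)  # (3, 3)
```

Because every message is indexed through the port arrays rather than by position, relabelling the addresses (and remapping the port indices accordingly) permutes the rows of the output, which is the equivariance the `Coupler` interface asks for.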