Coupler

class Coupler(*args, **kwargs)

Bases: Module, ABC

Interface for a coupler.

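A coupler takes a graph as input and returns latent coordinates for each of its addresses. Graph information should be injected into the latent coordinates in a permutation-equivariant manner: relabeling the addresses of the input graph relabels the output coordinates in the same way.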

Return type:

Any
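To make the permutation-equivariance requirement concrete, here is a minimal NumPy sketch. It assumes a hypothetical graph layout (a dict mapping port names to arrays of address indices, plus one feature row per hyper-edge) and a hypothetical `toy_coupler` that sets the coordinates of each address to the sum of the features of the hyper-edges attached to it. Neither is part of the library; the point is the final check, where relabeling the addresses with a permutation relabels the output rows in the same way.

```python
import numpy as np


def toy_coupler(ports, edge_features, n_addresses):
    """Hypothetical coupler: h_a = sum of the features of hyper-edges attached to a.

    ports: dict mapping a port name to an int array of address indices,
           one entry per hyper-edge (illustrative layout, not the library's).
    edge_features: array of shape [n_edges, f].
    """
    h = np.zeros((n_addresses, edge_features.shape[1]))
    for addresses in ports.values():
        # Scatter-add each hyper-edge's features onto the address at this port.
        np.add.at(h, addresses, edge_features)
    return h


# Toy graph: 3 addresses and 2 two-port hyper-edges, 0-1 and 1-2.
ports = {"from": np.array([0, 1]), "to": np.array([1, 2])}
features = np.array([[1.0], [10.0]])
h = toy_coupler(ports, features, n_addresses=3)

# Relabel addresses: old address i becomes new address perm[i].
perm = np.array([2, 0, 1])
permuted_ports = {name: perm[addr] for name, addr in ports.items()}
h_perm = toy_coupler(permuted_ports, features, n_addresses=3)

# Equivariance: the coordinates of old address i now sit in row perm[i].
assert np.allclose(h_perm[perm], h)
```

Any coupler built only from per-address operations and scatter-sums over hyper-edges passes this kind of check, since nothing in the computation depends on the order in which addresses are numbered.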

Implementations

class RecurrentCoupler(*args, **kwargs)

Bases: Coupler

Simplified version of NODECoupler that replaces the Neural Ordinary Differential Equation solver with a fixed number of explicit Euler steps.

The following recurrent system is used:

\[\forall a \in \mathcal{A}_x, h_a(t+\delta t) = h_a(t) + \delta t \times \phi_\theta(\psi^1_\theta(h(t);x)_a, \dots, \psi^n_\theta(h(t);x)_a),\]

with the following initial condition:

\[\forall a \in \mathcal{A}_x, h_a(t=0) = [0, \dots, 0].\]

Parameters:
  • phi – Outer MLP \(\phi_\theta\).

  • message_functions – List of message functions \((\psi^i_\theta)_i\).

  • n_steps – Number of message passing steps.

Return type:

Any
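The recurrence is an explicit Euler discretization of the ODE solved by NODECoupler. The sketch below spells it out in NumPy under stated assumptions: \(\phi_\theta\) and the \(\psi^i_\theta\) are plain Python callables, the messages are concatenated along the last axis before being passed to \(\phi_\theta\), and \(\delta t = 1/\)`n_steps` so that the recurrence covers \(t \in [0, 1]\). These choices, and the names `recurrent_coupler` and `x`, are illustrative rather than the library's API.

```python
import numpy as np


def recurrent_coupler(phi, message_functions, x, n_addresses, latent_dim, n_steps):
    """Explicit Euler recurrence h(t + dt) = h(t) + dt * phi(psi^1(h; x), ..., psi^n(h; x))."""
    h = np.zeros((n_addresses, latent_dim))  # initial condition h_a(0) = [0, ..., 0]
    dt = 1.0 / n_steps  # assumption: the n_steps steps cover t in [0, 1]
    for _ in range(n_steps):
        # Each message function returns one message per address: shape [n_addresses, m_i].
        messages = [psi(h, x) for psi in message_functions]
        h = h + dt * phi(np.concatenate(messages, axis=-1))
    return h


def identity(h, x):
    return h  # toy message function: psi(h; x)_a = h_a


def phi(m):
    return 1.0 - m  # toy outer map, so the underlying ODE is dh/dt = 1 - h


# Two Euler steps of size 0.5: 0 -> 0.5 -> 0.75.
h2 = recurrent_coupler(phi, [identity], x=None, n_addresses=4, latent_dim=3, n_steps=2)
# With many steps the result approaches the exact ODE solution 1 - exp(-1), about 0.632.
h_fine = recurrent_coupler(phi, [identity], x=None, n_addresses=4, latent_dim=3, n_steps=1000)
```

Each step costs exactly one evaluation of the message functions, so the cost is fixed by `n_steps`; NODECoupler instead lets an ODE solver choose its steps.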

class NODECoupler(*args, **kwargs)

Bases: Coupler

Output coordinates are computed by solving a Neural Ordinary Differential Equation.

The following ordinary differential equation is integrated between 0 and 1:

\[\forall a \in \mathcal{A}_x, \frac{dh_a}{dt} = \phi_\theta(\psi^1_\theta(h;x)_a, \dots, \psi^n_\theta(h;x)_a),\]

with the following initial condition:

\[\forall a \in \mathcal{A}_x, h_a(t=0) = [0, \dots, 0].\]

The implementation relies on Patrick Kidger’s Diffrax library.

Parameters:
  • phi – Outer MLP \(\phi_\theta\).

  • message_functions – List of message functions \((\psi^i_\theta)_i\).

  • latent_dimension – Dimension of the latent coordinates of each address.

  • dt – Initial step size.

  • stepsize_controller – Controller for adaptive step size methods.

  • adjoint – Method used to backpropagate through the ODE solve.

  • solver – Numerical solver for the ODE.

  • max_steps – Maximum number of steps the solver may take when solving the ODE.

Return type:

Any
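The sketch below integrates the same vector field from \(t=0\) to \(t=1\). It is hedged in two ways: it uses a hand-written, fixed-step classical Runge-Kutta (RK4) scheme in NumPy as a stand-in for the Diffrax solver, so there is no adaptive step size control and no adjoint method; and it reinterprets `dt` as a fixed step size and `max_steps` as a simple step budget. The name `node_coupler` and the concatenation of messages before \(\phi_\theta\) are assumptions.

```python
import math

import numpy as np


def node_coupler(phi, message_functions, x, n_addresses, latent_dimension, dt=0.1, max_steps=4096):
    """Integrate dh/dt = phi(psi^1(h; x), ..., psi^n(h; x)) on [0, 1] with fixed-step RK4."""

    def vector_field(h):
        # The right-hand side does not depend on t explicitly (autonomous ODE).
        messages = [psi(h, x) for psi in message_functions]
        return phi(np.concatenate(messages, axis=-1))

    n_steps = math.ceil(1.0 / dt)  # round up so the fixed steps end exactly at t = 1
    if n_steps > max_steps:
        raise RuntimeError(f"{n_steps} steps needed, but max_steps={max_steps}")
    step = 1.0 / n_steps

    h = np.zeros((n_addresses, latent_dimension))  # initial condition h_a(0) = [0, ..., 0]
    for _ in range(n_steps):
        k1 = vector_field(h)
        k2 = vector_field(h + 0.5 * step * k1)
        k3 = vector_field(h + 0.5 * step * k2)
        k4 = vector_field(h + step * k3)
        h = h + (step / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return h


def identity(h, x):
    return h  # toy message function: psi(h; x)_a = h_a


def phi(m):
    return 1.0 - m  # dh/dt = 1 - h, whose exact solution at t = 1 is 1 - exp(-1)


h1 = node_coupler(phi, [identity], x=None, n_addresses=4, latent_dimension=3, dt=0.1)
```

With only ten RK4 steps this toy problem is already solved to within about \(10^{-6}\), whereas the explicit Euler recurrence needs on the order of a thousand steps to get within \(10^{-3}\). An adaptive solver such as the one Diffrax provides pushes this further by choosing the step sizes automatically.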

Message Passing Functions

class MessagePassingFunction(*args, **kwargs)

Bases: Module, ABC

Interface for a message function \(\psi_\theta\) in a GNN message passing scheme.

Return type:

Any

class IdentityMessagePassingFunction(*args, **kwargs)

Bases: MessagePassingFunction

Identity local message function module for GNN message passing.

This module returns the latent coordinates of each address unchanged as its local message, i.e. it implements the identity mapping:

\[\psi_\theta(h;x)_a = h_a\]

Return type:

Any
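A minimal sketch of the interface and of the identity message function, written with a plain `abc.ABC` base instead of the library's `Module` base. The call signature `(h, x) -> messages`, with `h` an array of shape [n_addresses, d], is an assumption chosen to match \(\psi_\theta(h;x)_a\); it is not the library's actual signature.

```python
from abc import ABC, abstractmethod

import numpy as np


class MessagePassingFunction(ABC):
    """Hypothetical stand-in for the interface: maps (h, x) to one message per address."""

    @abstractmethod
    def __call__(self, h, x):
        """Return an array of shape [n_addresses, message_size]."""


class IdentityMessagePassingFunction(MessagePassingFunction):
    """psi(h; x)_a = h_a: the graph data x is ignored."""

    def __call__(self, h, x):
        return h


h = np.arange(6.0).reshape(3, 2)
messages = IdentityMessagePassingFunction()(h, x=None)
```

Passed to a coupler alongside other message functions, the identity presumably lets \(\phi_\theta\) see the current coordinates \(h_a\) next to the messages aggregated from the neighborhood of \(a\).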

class LocalSumMessagePassingFunction(*args, **kwargs)

Bases: MessagePassingFunction

Local sum-based message function module for GNN message passing.

This module aggregates messages over the local neighborhood of each address: it applies a class- and port-specific MLP \(\xi^{c,o}_\theta\) to the features and port coordinates of every hyper-edge attached to the address, sums the results over all ports through which the address is connected, and applies a final activation \(\sigma\).

For each address \(a\), the output is defined as:

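\[\psi_\theta(h;x)_a = \sigma \left( \sum_{(c,e,o)\in \mathcal{N}_x(a)} \xi^{c,o}_\theta(h_e, x_e)\right),\]

where \(\mathcal{N}_x(a)\) is the set of triplets \((c,e,o)\) such that port \(o\) of hyper-edge \(e\) of class \(c\) is connected to address \(a\), \(\xi^{c,o}_\theta\) is a class-specific and port-specific MLP, \(\sigma\) is an element-wise activation function, \(x_e\) denotes the features of hyper-edge \(e\), and \(h_e := (h_{o(e)})_{o \in \mathcal{O}^c}\) is the concatenation of the coordinates of the addresses connected to the ports of hyper-edge \(e\).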

Parameters:
  • in_graph_structure – Input graph structure.

  • in_array_size – Size of the input coordinate arrays.

  • hidden_sizes – Hidden sizes of the MLPs \(\xi^{c,o}_\theta\).

  • activation – Activation function for the MLPs \(\xi^{c,o}_\theta\).

  • out_size – Output size of the MLPs \(\xi^{c,o}_\theta\).

  • use_bias – Whether to use bias in the MLPs \(\xi^{c,o}_\theta\).

  • kernel_init – Kernel initializer for the MLPs \(\xi^{c,o}_\theta\).

  • bias_init – Bias initializer for the MLPs \(\xi^{c,o}_\theta\).

  • final_activation – Activation function applied to the output layer of the MLPs \(\xi^{c,o}_\theta\) (distinct from outer_activation).

  • outer_activation – Activation function \(\sigma\) applied over the output.

  • encoded_feature_size – None if the input data has not been encoded, otherwise the size of the encoded features.

  • port_scatter_blacklist – Dictionary mapping hyper-edge set keys to lists of port keys to be excluded from the sum.

  • seed – Seed of the RNG streams used for weight initialization.

Return type:

Any
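The NumPy sketch below evaluates the formula for a single forward pass. Several details are assumptions rather than the library's implementation: the graph layout (per hyper-edge class, a dict of port name to address indices plus a feature array), a single random linear layer standing in for each MLP \(\xi^{c,o}_\theta\), \(\sigma = \tanh\), and the reading of `port_scatter_blacklist` as "do not scatter the messages of these ports". The names `init_local_sum` and `local_sum`, and the class key `"line"`, are hypothetical.

```python
import numpy as np


def init_local_sum(x, latent_dim, out_size, seed=0):
    """One random linear layer per (class c, port o), standing in for the MLP xi^{c,o}."""
    rng = np.random.default_rng(seed)
    params = {}
    for c, edges in x.items():
        in_size = len(edges["ports"]) * latent_dim + edges["features"].shape[1]
        params[c] = {o: rng.normal(size=(in_size, out_size)) / np.sqrt(in_size)
                     for o in edges["ports"]}
    return params


def local_sum(params, h, x, out_size, outer_activation=np.tanh, port_scatter_blacklist=None):
    """psi(h; x)_a = sigma(sum over (c, e, o) in N_x(a) of xi^{c,o}(h_e, x_e))."""
    blacklist = port_scatter_blacklist or {}
    total = np.zeros((h.shape[0], out_size))
    for c, edges in x.items():
        # h_e: concatenation of the coordinates of every port of each hyper-edge.
        h_e = np.concatenate([h[addr] for addr in edges["ports"].values()], axis=-1)
        inputs = np.concatenate([h_e, edges["features"]], axis=-1)
        for o, addresses in edges["ports"].items():
            if o in blacklist.get(c, ()):
                continue  # this port does not contribute to the sum
            # Send xi^{c,o}(h_e, x_e) to the address attached to port o of each hyper-edge.
            np.add.at(total, addresses, inputs @ params[c][o])
    return outer_activation(total)


# Toy graph: 4 addresses (address 3 is isolated), two "line" hyper-edges 0-1 and 1-2.
x = {"line": {"ports": {"from": np.array([0, 1]), "to": np.array([1, 2])},
              "features": np.array([[1.0], [-2.0]])}}
h = np.arange(8.0).reshape(4, 2) / 10.0
params = init_local_sum(x, latent_dim=2, out_size=3, seed=0)
out = local_sum(params, h, x, out_size=3)
out_no_to = local_sum(params, h, x, out_size=3, port_scatter_blacklist={"line": ["to"]})
```

An address with no attached hyper-edge receives \(\sigma(0)\), which is zero for \(\tanh\); likewise, blacklisting the "to" port zeroes out address 2, which is reached only through that port.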