GNN.forward_batch

GNN.forward_batch(*, graph, get_info=False)

Applies the model to a batch of graphs.

The encoder, coupler, and decoder modules are vectorized over the batch with vmap; the normalization module is not vmapped.

Parameters:
  • graph (JaxGraph) – Batch of input graphs.

  • get_info (bool) – Whether to return additional information about the processing steps. Defaults to False.

Return type:

tuple[JaxGraph | Array, dict]
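
The split between vmapped and non-vmapped modules can be illustrated with a minimal JAX sketch. This is not the library's implementation: `normalize`, `encoder`, `coupler`, `decoder`, and `forward_batch_sketch` are hypothetical stand-ins that operate on plain arrays rather than on a JaxGraph, and the sketch returns only an array instead of the documented `(output, info)` tuple. It assumes each per-graph module is written for a single graph and is lifted to the batch with `jax.vmap`, while normalization is elementwise and is therefore called on the batched array as is.

```python
import jax
import jax.numpy as jnp


# All functions below are hypothetical stand-ins, not the library's modules.
def normalize(x, mean, std):
    # Elementwise, so it accepts a batched array without vmap.
    return (x - mean) / std


def encoder(x, w_enc):
    # Written for ONE graph: x has shape (n_nodes, n_features).
    return jnp.tanh(x @ w_enc)


def coupler(h):
    # Toy coupling step: center the latent features across the nodes of one graph.
    return h - h.mean(axis=0, keepdims=True)


def decoder(h, w_dec):
    # Written for ONE graph: h has shape (n_nodes, n_latent).
    return h @ w_dec


def forward_batch_sketch(x_batch, w_enc, w_dec, mean, std):
    # Normalization is called once on the whole batch (no vmap).
    x_batch = normalize(x_batch, mean, std)
    # Per-graph modules are mapped over the leading batch axis;
    # the weights are shared, hence in_axes=None for them.
    h = jax.vmap(encoder, in_axes=(0, None))(x_batch, w_enc)
    h = jax.vmap(coupler)(h)
    return jax.vmap(decoder, in_axes=(0, None))(h, w_dec)


key_x, key_enc, key_dec = jax.random.split(jax.random.PRNGKey(0), 3)
x_batch = jax.random.normal(key_x, (4, 5, 3))  # 4 graphs, 5 nodes, 3 features
w_enc = jax.random.normal(key_enc, (3, 8))
w_dec = jax.random.normal(key_dec, (8, 2))

out = forward_batch_sketch(x_batch, w_enc, w_dec, mean=0.5, std=2.0)
print(out.shape)  # (4, 5, 2)
```

In this sketch, each entry `out[i]` matches what the three per-graph functions would produce when applied to `x_batch[i]` alone, which is the property that vmapping is meant to preserve.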