Trainer

class Trainer(*, model, gradient_transformation)[source]

Trainer implementation.

This basic trainer trains a permutation-equivariant Graph Neural Network \(\hat{y}_\theta\) over a dataset of problem instances. For a fixed problem instance with objective function \(f\) and context \(x\), the parameter \(\theta\) is updated by the following gradient descent step,

\[\theta \gets \theta - \alpha \, J_\theta[\hat{y}_\theta](x)^\top \, \nabla_y f(\hat{y}_\theta(x); x),\]

where \(J_\theta[\hat{y}_\theta]\) is the Jacobian matrix of the GNN \(\hat{y}_\theta\) with respect to its parameters \(\theta\), and \(\nabla_y f\) is the gradient of the objective function \(f\) with respect to the decision \(y\). For readability, the update is written as plain gradient descent with learning rate \(\alpha\), but more complex optimizers are possible.
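The update rule above is the chain rule applied to \(\theta \mapsto f(\hat{y}_\theta(x); x)\). The following is a minimal JAX sketch of one such step, not the trainer's actual implementation. It assumes a toy linear map in place of the GNN and a quadratic objective; the names `y_hat`, `f` and `gradient_step` are hypothetical and chosen only for illustration. The product \(J_\theta[\hat{y}_\theta](x)^\top \nabla_y f\) is computed with `jax.vjp`, then compared with the gradient of the composite objective.

```python
import jax
import jax.numpy as jnp


def y_hat(theta, x):
    # Hypothetical stand-in for the GNN: a linear map from context to decision.
    return theta @ x


def f(y, x):
    # Hypothetical objective: half the squared distance between decision and context.
    return 0.5 * jnp.sum((y - x) ** 2)


def gradient_step(theta, x, alpha):
    # Forward pass; vjp_fn applies the transposed Jacobian J_theta[y_hat](x)^T.
    y, vjp_fn = jax.vjp(lambda t: y_hat(t, x), theta)
    # Gradient of the objective with respect to the decision y.
    grad_y = jax.grad(f, argnums=0)(y, x)
    # Pull the decision gradient back to parameter space.
    (grad_theta,) = vjp_fn(grad_y)
    return theta - alpha * grad_theta, grad_theta


theta = jnp.array([[1.0, 2.0], [3.0, 4.0]])
x = jnp.array([0.5, -1.0])
alpha = 0.1

new_theta, grad_theta = gradient_step(theta, x, alpha)

# Same gradient, obtained by differentiating the composite objective directly.
direct_grad = jax.grad(lambda t: f(y_hat(t, x), x))(theta)
```

In practice, differentiating the composite objective with `jax.grad` is the shorter route; the explicit `jax.vjp` form is shown only because it mirrors the two factors in the formula.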

After every training epoch, the current trainer is checkpointed.

Parameters:
  • model (GNN) – Core Graph Neural Network model.

  • gradient_transformation (optax.GradientTransformation) – Optax gradient transformation.

Trainer.train

Trains the model over the train loader, periodically validates the model, tracks metrics, and checkpoints the model.

Trainer.run_evaluation

Runs an evaluation and checkpoints.

Trainer.save_checkpoint

Saves the current model and optimizer state as a checkpoint.

Trainer.load_checkpoint

Loads a checkpoint from the checkpoint manager.