Trainer
- class Trainer(*, model, gradient_transformation)
Trainer implementation.
This basic trainer trains a permutation-equivariant Graph Neural Network \(\hat{y}_\theta\) over a dataset of problem instances. For a fixed problem instance with objective function \(f\) and context \(x\), the parameter \(\theta\) is updated by the following gradient descent step:
\[\theta \gets \theta - \alpha \, J_\theta[\hat{y}_\theta](x)^\top \, \nabla_y f(\hat{y}_\theta(x); x),\]
where \(J_\theta[\hat{y}_\theta]\) is the Jacobian matrix of the GNN \(\hat{y}_\theta\) with respect to \(\theta\), and \(\nabla_y f\) is the gradient of the objective function \(f\) with respect to the decision \(y\). For readability, the update is written as plain gradient descent with learning rate \(\alpha\), but more complex optimizers are possible.
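The update direction \(J_\theta[\hat{y}_\theta](x)^\top \nabla_y f\) is a vector-Jacobian product, which by the chain rule equals the gradient of the composed map \(\theta \mapsto f(\hat{y}_\theta(x); x)\). The following is a minimal JAX sketch of that identity, not the trainer's actual implementation: `y_hat` and `f` are hypothetical stand-ins for the GNN and the objective, and the values of `theta`, `x`, and `alpha` are arbitrary.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the GNN: maps a context x to a decision y.
def y_hat(theta, x):
    return jnp.tanh(theta @ x)

# Hypothetical objective f(y; x): squared distance to a target read from x.
def f(y, x):
    return jnp.sum((y - x[: y.shape[0]]) ** 2)

theta = jnp.array([[0.1, -0.2, 0.3], [0.4, 0.0, -0.1]])
x = jnp.array([1.0, 2.0, -1.0])
alpha = 0.1  # learning rate

# Explicit form: J_theta[y_hat](x)^T applied to grad_y f(y_hat(x); x).
y, vjp_fn = jax.vjp(lambda t: y_hat(t, x), theta)
grad_y = jax.grad(f, argnums=0)(y, x)
(explicit_grad,) = vjp_fn(grad_y)

# Same quantity, obtained by differentiating the composed loss directly.
composed_grad = jax.grad(lambda t: f(y_hat(t, x), x))(theta)

# One plain gradient descent step on theta.
new_theta = theta - alpha * explicit_grad
```

In practice the composed form is the convenient one, since reverse-mode differentiation never materializes the Jacobian.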
After every training epoch, the current trainer is checkpointed.
- Parameters:
model (GNN) – Core Graph Neural Network model.
gradient_transformation (optax.GradientTransformation) – Optax gradient transformation.
Trains the model over the train loader, periodically validates the model, tracks metrics, and checkpoints the model.
Runs an evaluation and checkpoints.
Saves the current model and optimizer state as a checkpoint.
Loads a checkpoint from the checkpoint manager.