Trainer.train

Trainer.train(*, train_loader, val_loader=None, checkpoint_manager=None, n_epochs, tracker=None, log_period=None, eval_period=None, eval_before_training=False, eval_after_epoch=True, progress_bar=True, optim_mode='minimize')

Trains the model on the training loader, periodically evaluates it on the validation loader, tracks metrics, and saves checkpoints of the model.
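A minimal sketch of the phase ordering described above, written as a plain Python function rather than the library's actual code. The function `train_sketch`, its callables `train_step` and `evaluate`, and the placement of the checkpoint step right after each end-of-epoch evaluation are hypothetical; logging, tracking and `optim_mode` handling are left out to keep the control flow visible.

```python
from typing import Callable, List, Sequence


def train_sketch(
    train_batches: Sequence[int],
    n_epochs: int,
    train_step: Callable[[int], object],
    evaluate: Callable[[], float],
    eval_before_training: bool = False,
    eval_after_epoch: bool = True,
) -> List[str]:
    """Hypothetical outline of the phases of a training run; returns an event trace."""
    events: List[str] = []
    if eval_before_training:
        # Full pass over the validation data before any parameter update.
        events.append(f"eval:{evaluate():.2f}")
    for epoch in range(1, n_epochs + 1):
        for batch in train_batches:
            train_step(batch)  # one optimisation step per training batch
        events.append(f"epoch:{epoch}")
        if eval_after_epoch:
            events.append(f"eval:{evaluate():.2f}")
            # Assumed: a checkpoint manager would save the model at this point.
            events.append(f"checkpoint:{epoch}")
    return events
```

Recording events as strings keeps the sketch free of any model or loader dependency, so the order of evaluations and checkpoints can be inspected directly.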

Parameters:
  • train_loader (ProblemLoader) – Problem loader used for training.

  • val_loader (ProblemLoader | None) – Problem loader used for validation.

  • checkpoint_manager (CheckpointManager | None) – Checkpoint manager for saving checkpoints.

  • n_epochs (int) – Number of training epochs to perform.

  • tracker (Tracker | None) – Experiment tracker.

  • log_period (int | None) – Number of training iterations between two consecutive logs; None disables logging.

  • eval_period (int | None) – Number of training epochs between two consecutive evaluations; None disables these periodic evaluations.

  • eval_before_training (bool) – If True, evaluates metrics over the full validation loader before training starts.

  • eval_after_epoch (bool) – If True, evaluates metrics over the full validation loader after each epoch.

  • progress_bar (bool) – If True, displays a progress bar during training.

  • optim_mode (Literal['minimize', 'maximize']) – Optimization mode: “minimize” if lower validation scores are better, “maximize” if higher scores are better. Overrides the checkpoint manager’s best_mode.
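The log_period and eval_period parameters act as periodic triggers. The helper below is a hypothetical sketch, assuming counters start at 1 and an action fires whenever the counter is a multiple of the period; the exact counting convention used by Trainer.train is not stated in this reference.

```python
from typing import List, Optional


def trigger_points(total: int, period: Optional[int]) -> List[int]:
    """Return the 1-based counters (iterations or epochs) at which an action fires.

    Assumption: an action fires whenever the counter is a multiple of `period`,
    and `period=None` disables the action entirely.
    """
    if period is None:
        return []
    if period <= 0:
        raise ValueError("period must be a positive integer or None")
    return [step for step in range(1, total + 1) if step % period == 0]


# Logging every 100 iterations over 350 iterations; evaluating every 2 epochs over 5 epochs.
log_points = trigger_points(350, 100)  # [100, 200, 300]
eval_points = trigger_points(5, 2)     # [2, 4]
no_logs = trigger_points(350, None)    # []
```

Treating None as "never fire" mirrors the documented meaning of None for both parameters.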

Returns:

Best average score obtained on the validation loader, where “best” is determined by optim_mode.

Return type:

float
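One way the returned best score could be tracked under optim_mode, sketched under the assumption that each evaluation yields a single average score and that only a strict improvement replaces the current best. `is_improvement` and `best_validation_score` are hypothetical names, and the value returned when no evaluation takes place (positive or negative infinity here) is an assumption, not documented behaviour.

```python
from typing import Iterable, Literal

OptimMode = Literal["minimize", "maximize"]


def is_improvement(new: float, best: float, optim_mode: OptimMode) -> bool:
    """Return True if `new` strictly beats `best` under `optim_mode`."""
    if optim_mode == "minimize":
        return new < best
    if optim_mode == "maximize":
        return new > best
    raise ValueError(f"unknown optim_mode: {optim_mode!r}")


def best_validation_score(scores: Iterable[float], optim_mode: OptimMode) -> float:
    """Return the best of a sequence of average validation scores."""
    if optim_mode not in ("minimize", "maximize"):
        raise ValueError(f"unknown optim_mode: {optim_mode!r}")
    # Start from the worst possible value so the first score always counts.
    best = float("inf") if optim_mode == "minimize" else float("-inf")
    for score in scores:
        if is_improvement(score, best, optim_mode):
            best = score  # assumed point where a checkpoint of the best model is saved
    return best


print(best_validation_score([0.52, 0.31, 0.40], "minimize"))  # 0.31
print(best_validation_score([0.52, 0.31, 0.40], "maximize"))  # 0.52
```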