optuna.integration.PyTorchLightningPruningCallback

class optuna.integration.PyTorchLightningPruningCallback(trial, monitor)

PyTorch Lightning callback to prune unpromising trials.

See the example if you want to add a pruning callback that observes accuracy.

Parameters:
  • trial (Trial) – A Trial corresponding to the current evaluation of the objective function.

  • monitor (str) – An evaluation metric for pruning, e.g., val_loss or val_acc. The metrics are obtained from the dictionaries returned by, e.g., pytorch_lightning.LightningModule.training_step or pytorch_lightning.LightningModule.validation_epoch_end, so the metric names depend on how those dictionaries are formatted.
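A minimal, library-free sketch of the per-epoch logic, assuming a simplified model of the callback: after each validation run, the value logged under monitor is looked up by name, reported to the trial as the intermediate value for the current epoch, and the trial is pruned if the pruner decides so. StubTrial, TrialPrunedSketch and PruningCallbackSketch are hypothetical stand-ins for optuna.Trial, optuna.TrialPruned and the real callback; the fixed-threshold rule replaces a real pruner such as optuna.pruners.MedianPruner, and the plain metrics dict plays the role of the metrics Lightning collects (e.g., trainer.callback_metrics). This is not Optuna's implementation.

```python
import warnings


class TrialPrunedSketch(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


class StubTrial:
    """Hypothetical stand-in for optuna.Trial.

    Pruning rule (assumption for this sketch only): prune as soon as the
    most recently reported value is above a fixed threshold.
    """

    def __init__(self, threshold):
        self.threshold = threshold
        self.intermediate_values = {}

    def report(self, value, step):
        self.intermediate_values[step] = value

    def should_prune(self):
        latest_step = max(self.intermediate_values)
        return self.intermediate_values[latest_step] > self.threshold


class PruningCallbackSketch:
    """Simplified model of the callback's per-epoch logic (not Optuna's code)."""

    def __init__(self, trial, monitor):
        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, epoch, metrics):
        # `metrics` plays the role of the metrics logged by the LightningModule.
        score = metrics.get(self.monitor)
        if score is None:
            # An unlogged or misspelled metric name cannot drive pruning.
            warnings.warn(f"Metric {self.monitor!r} not found; skipping pruning check.")
            return
        self._trial.report(float(score), step=epoch)
        if self._trial.should_prune():
            raise TrialPrunedSketch(f"Trial was pruned at epoch {epoch}.")


# Usage: val_loss rises above the threshold at epoch 2, so the trial is pruned.
trial = StubTrial(threshold=0.5)
callback = PruningCallbackSketch(trial, monitor="val_loss")
pruned_at = None
for epoch, val_loss in enumerate([0.45, 0.40, 0.62, 0.30]):
    try:
        callback.on_validation_end(epoch, {"val_loss": val_loss})
    except TrialPrunedSketch:
        pruned_at = epoch
        break
print(pruned_at)  # 2
```

Because the metric is looked up by name, a monitor value that matches no logged key (for example, val_acc when only val_loss is logged) simply skips the pruning check in this sketch, which is why the names in the returned dictionaries matter.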

Note

For distributed data parallel (DDP) training, PyTorch Lightning v1.6.0 or later is required. In addition, the Study should be created with RDB storage.

Note

If you would like to use PyTorchLightningPruningCallback in a distributed training environment, you need to invoke PyTorchLightningPruningCallback.check_pruned() manually so that TrialPruned is properly handled.

Methods

  • check_pruned() – Raise optuna.TrialPruned manually if pruned.

  • on_fit_start(trainer, pl_module)

  • on_validation_end(trainer, pl_module)

check_pruned()

Raise optuna.TrialPruned manually if pruned.

Currently, intermediate_values are not properly propagated between processes because of the storage cache. Therefore, when the trial runs in a distributed setting, the necessary information is kept in trial_system_attrs instead. Call this method right after calling pytorch_lightning.Trainer.fit(). If the callback does not have any backend storage for DDP, this method does nothing.
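The hand-off that check_pruned() relies on can be sketched without PyTorch Lightning, assuming a simplified model in which, under DDP, the callback records the pruning decision and asks training to stop instead of raising, and check_pruned() raises afterwards in the calling process. The plain dict stands in for the trial's system attributes in RDB storage; its keys, the uses_ddp_storage flag, fake_fit() and all classes are hypothetical, and the prune argument replaces the pruner's decision to keep the example short.

```python
class TrialPrunedSketch(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


class DistributedPruningSketch:
    """Simplified model of the hand-off between the callback and check_pruned().

    system_attrs is a plain dict standing in for the trial's system attributes
    in RDB storage; the key names used here are hypothetical.
    """

    def __init__(self, system_attrs, uses_ddp_storage=True):
        self._system_attrs = system_attrs
        self.uses_ddp_storage = uses_ddp_storage
        self.should_stop = False  # plays the role of trainer.should_stop

    def on_validation_end(self, epoch, prune):
        # `prune` replaces the pruner's decision to keep the sketch short.
        if not prune:
            return
        if not self.uses_ddp_storage:
            # Single-process case: raise immediately.
            raise TrialPrunedSketch(f"Trial was pruned at epoch {epoch}.")
        # Distributed case: record the decision and ask training to stop.
        self._system_attrs["pruned"] = True
        self._system_attrs["epoch"] = epoch
        self.should_stop = True

    def check_pruned(self):
        # Without backend storage for DDP there is nothing to check.
        if not self.uses_ddp_storage:
            return
        if self._system_attrs.get("pruned", False):
            epoch = self._system_attrs["epoch"]
            raise TrialPrunedSketch(f"Trial was pruned at epoch {epoch}.")


def fake_fit(callback, prune_at, max_epochs=5):
    """Hypothetical stand-in for Trainer.fit(); returns the last epoch run."""
    for epoch in range(max_epochs):
        callback.on_validation_end(epoch, prune=(epoch == prune_at))
        if callback.should_stop:
            return epoch
    return max_epochs - 1


attrs = {}
callback = DistributedPruningSketch(attrs)
last_epoch = fake_fit(callback, prune_at=3)  # returns normally, no exception yet
try:
    callback.check_pruned()  # called right after fit(), as recommended above
    outcome = "completed"
except TrialPrunedSketch as exc:
    outcome = str(exc)
print(last_epoch, outcome)  # 3 Trial was pruned at epoch 3.
```

In this model, fit() returns normally even when the trial was pruned, so skipping the check_pruned() call would make a pruned trial look completed to the study.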

Return type:

None