Efficient Optimization Algorithms
Optuna enables efficient hyperparameter optimization by adopting state-of-the-art algorithms for sampling hyperparameters and efficiently pruning unpromising trials.
Sampling Algorithms
Samplers continually narrow down the search space using the records of suggested parameter values and evaluated objective values, converging on a search space that yields parameters with better objective values. A more detailed explanation of how samplers suggest parameters is in optuna.samplers.BaseSampler.
Optuna provides the following sampling algorithms:
- Tree-structured Parzen Estimator algorithm implemented in optuna.samplers.TPESampler
- CMA-ES based algorithm implemented in optuna.samplers.CmaEsSampler
- Grid Search implemented in optuna.samplers.GridSampler
- Random Search implemented in optuna.samplers.RandomSampler
The default sampler is optuna.samplers.TPESampler.
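Note that optuna.samplers.GridSampler, unlike the other samplers, requires the full search space up front. A minimal sketch, assuming a single hypothetical parameter named "alpha":

import optuna

# GridSampler exhaustively evaluates every combination of the listed values,
# so the grid must be specified when the sampler is constructed.
search_space = {"alpha": [1e-5, 1e-3, 1e-1]}
study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))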
Switching Samplers
import optuna
By default, Optuna uses TPESampler as follows.
study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Out:
Sampler is TPESampler
If you want to use different samplers, for example RandomSampler and CmaEsSampler:
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Out:
Sampler is RandomSampler
Sampler is CmaEsSampler
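Most samplers also accept a seed argument, which makes the sequence of suggested parameters reproducible. A minimal sketch (the seed value 10 is arbitrary):

import optuna

# Fixing the sampler's seed gives reproducible parameter suggestions.
study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=10))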
Pruning Algorithms
Pruners automatically stop unpromising trials at the early stages of training (a.k.a. automated early stopping).
Optuna provides the following pruning algorithms:
- Asynchronous Successive Halving algorithm implemented in optuna.pruners.SuccessiveHalvingPruner
- Hyperband algorithm implemented in optuna.pruners.HyperbandPruner
- Median pruning algorithm implemented in optuna.pruners.MedianPruner
- Threshold pruning algorithm implemented in optuna.pruners.ThresholdPruner
We use optuna.pruners.MedianPruner in most examples, though it is generally outperformed by optuna.pruners.SuccessiveHalvingPruner and optuna.pruners.HyperbandPruner, as shown in this benchmark result.
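Switching pruners works just like switching samplers: pass the pruner to optuna.create_study(). A minimal sketch with default arguments:

import optuna

# Either pruner can replace the default MedianPruner.
study = optuna.create_study(pruner=optuna.pruners.SuccessiveHalvingPruner())
study = optuna.create_study(pruner=optuna.pruners.HyperbandPruner())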
Activating Pruners
To turn on the pruning feature, you need to call report() and should_prune() after each step of the iterative training. report() periodically monitors the intermediate objective values. should_prune() decides whether to terminate a trial that does not meet a predefined condition.
We recommend using the integration modules for major machine learning frameworks. The exhaustive list is optuna.integration, and use cases are available in optuna/examples.
import logging
import sys

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_loguniform("alpha", 1e-5, 1e-1)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)
Set up the median stopping rule as the pruning condition.
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
Out:
A new study created in memory with name: no-name-e703d80e-4bd9-429b-9850-2d4e17713e4a
Trial 0 finished with value: [0.1842105263157895] and parameters: {'alpha': 0.0019897925284271334}.
Trial 1 finished with value: [0.07894736842105265] and parameters: {'alpha': 0.03525073338974512}.
Trial 2 finished with value: [0.07894736842105265] and parameters: {'alpha': 0.0053404533493436055}.
Trial 3 finished with value: [0.21052631578947367] and parameters: {'alpha': 1.555865846009227e-05}.
Trial 4 finished with value: [0.1842105263157895] and parameters: {'alpha': 0.003852364228427479}.
Trial 5 pruned.
Trial 6 finished with value: [0.052631578947368474] and parameters: {'alpha': 3.5964079185803114e-05}.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 pruned.
Trial 12 pruned.
Trial 13 pruned.
Trial 14 pruned.
Trial 15 pruned.
Trial 16 pruned.
Trial 17 pruned.
Trial 18 pruned.
Trial 19 pruned.
As you can see, several trials were pruned (stopped) before they finished all of the iterations. The format of the message is "Trial <Trial Number> pruned."
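If you want to inspect the outcome programmatically rather than by parsing the log, each trial records its state. A minimal sketch, assuming the study object from above:

from optuna.trial import TrialState

# Count how many trials were pruned versus run to completion.
pruned_trials = [t for t in study.trials if t.state == TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
print(f"Pruned: {len(pruned_trials)}, Complete: {len(complete_trials)}")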
Which Sampler and Pruner Should be Used?
From the benchmark results, which are available at optuna/optuna - wiki "Benchmarks with Kurobako", at least for non-deep-learning tasks, we would say that
- For optuna.samplers.RandomSampler, optuna.pruners.MedianPruner is the best.
- For optuna.samplers.TPESampler, optuna.pruners.HyperbandPruner is the best.
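For example, creating a study with the recommended pair for TPE looks like this (a minimal sketch with default arguments):

import optuna

# TPE sampler paired with the Hyperband pruner, per the benchmark above.
study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(),
    pruner=optuna.pruners.HyperbandPruner(),
)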
However, note that the benchmark does not cover deep learning tasks. For deep learning tasks, consult the table below from Ozaki et al., Hyperparameter Optimization Methods: Overview and Characteristics, IEICE Trans., Vol. J103-D, No. 9, pp. 615-631, 2020.
| Parallel Compute Resource | Categorical/Conditional Hyperparameters | Recommended Algorithms |
|---|---|---|
| Limited | No | TPE. GP-EI if search space is low-dimensional and continuous. |
| Limited | Yes | TPE. GP-EI if search space is low-dimensional and continuous. |
| Sufficient | No | CMA-ES, Random Search |
| Sufficient | Yes | Random Search or Genetic Algorithm |
Integration Modules for Pruning
To implement the pruning mechanism in a much simpler form, Optuna provides integration modules for the following libraries. For the complete list of Optuna's integration modules, see optuna.integration.
For example, XGBoostPruningCallback introduces pruning without directly changing the logic of the training iteration. (See also the example for the entire script.)
import xgboost as xgb

# `param`, `dtrain`, and `dvalid` are assumed to be defined elsewhere in the script.
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-error')
bst = xgb.train(param, dtrain, evals=[(dvalid, 'validation')], callbacks=[pruning_callback])
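Under the hood, such callbacks call report() and should_prune() for you at each boosting round and raise TrialPruned when the trial should stop, so no manual instrumentation of the training loop is needed.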