import optuna
with optuna._imports.try_import() as _imports:
import xgboost as xgb
def _get_callback_context(env: "xgb.core.CallbackEnv") -> str:
"""Return whether the current callback context is cv or train.
.. note::
`Reference
<https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py>`_.
"""
if env.model is None and env.cvfolds is not None:
context = "cv"
else:
context = "train"
return context
[docs]class XGBoostPruningCallback(object):
"""Callback for XGBoost to prune unpromising trials.
See `the example <https://github.com/optuna/optuna/blob/master/
examples/pruning/xgboost_integration.py>`__
if you want to add a pruning callback which observes validation AUC of
a XGBoost model.
Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
observation_key:
An evaluation metric for pruning, e.g., ``validation-error`` and
``validation-merror``. When using the Scikit-Learn API, the index number of
``eval_set`` must be included in the ``observation_key``, e.g., ``validation_0-error``
and ``validation_0-merror``. Please refer to ``eval_metric`` in
`XGBoost reference <https://xgboost.readthedocs.io/en/latest/parameter.html>`_
for further details.
"""
def __init__(self, trial: optuna.trial.Trial, observation_key: str) -> None:
_imports.check()
self._trial = trial
self._observation_key = observation_key
def __call__(self, env: "xgb.core.CallbackEnv") -> None:
context = _get_callback_context(env)
evaluation_result_list = env.evaluation_result_list
if context == "cv":
# Remove a third element: the stddev of the metric across the cross-valdation folds.
evaluation_result_list = [(key, metric) for key, metric, _ in evaluation_result_list]
current_score = dict(evaluation_result_list)[self._observation_key]
self._trial.report(current_score, step=env.iteration)
if self._trial.should_prune():
message = "Trial was pruned at iteration {}.".format(env.iteration)
raise optuna.TrialPruned(message)