The `TrainingCallback` class in XGBoost provides a powerful way to customize the behavior of the training process by allowing users to define their own callbacks. These callbacks can be used to implement custom logging, early stopping, or other modifications to the training loop. In this example, we'll demonstrate how to use the `TrainingCallback` class to log training progress and implement early stopping.
To define a custom callback, we create a new class that inherits from `TrainingCallback` and overrides its methods. The most commonly used method is `after_iteration()`, which is called after each iteration (or round) of training. Here's an example of how to define a custom callback:
```python
import xgboost


class CustomCallback(xgboost.callback.TrainingCallback):
    def __init__(self, early_stopping_rounds, validation_data):
        super().__init__()
        self.early_stopping_rounds = early_stopping_rounds
        self.validation_data = validation_data
        self.best_score = float("inf")
        self.best_iteration = None
        self.current_round = 0

    def after_iteration(self, model, epoch, evals_log):
        # evals_log is a nested dict, e.g. {"validation": {"rmse": [0.52, 0.41, ...]}}
        current_score = evals_log["validation"]["rmse"][-1]
        self.current_round += 1
        if current_score < self.best_score:
            self.best_score = current_score
            self.best_iteration = self.current_round
        print(f"Round {self.current_round}: Best iteration = {self.best_iteration}, Best score = {self.best_score:.4f}")
        if self.best_iteration is not None and (self.current_round - self.best_iteration >= self.early_stopping_rounds):
            print(f"Early stopping at round {self.current_round}")
            model.set_attr(best_iteration=str(self.best_iteration))
            # returning True stops training
            return True
        # returning False (or None) continues training
        return False
```
In this example, the `CustomCallback` class is initialized with the number of rounds to wait before early stopping (`early_stopping_rounds`) and the validation data (`validation_data`). The `after_iteration()` method is called after each training iteration and performs the following steps:

- Retrieves the current validation score from `evals_log`, a nested dictionary keyed by evaluation set name and then metric name.
- Updates the `current_round` counter.
- Checks if the current score is better than the best score seen so far and updates `best_score` and `best_iteration` accordingly.
- Prints the current round, best iteration, and best score.
- Checks if the number of rounds since the best iteration has reached the `early_stopping_rounds` threshold. If so, it prints a message, sets the best iteration attribute on the model, and returns `True` to stop training.
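`after_iteration()` is not the only hook available: `TrainingCallback` also exposes `before_training()`, `after_training()`, and `before_iteration()`, which can be overridden in the same way. As a minimal sketch (the timing logic here is illustrative and not part of the original example), a callback that reports total training time might look like this:

```python
import time

import xgboost


class TimerCallback(xgboost.callback.TrainingCallback):
    """Illustrative callback using the before/after training hooks."""

    def before_training(self, model):
        # called once before the first boosting round; must return the model
        self.start = time.perf_counter()
        return model

    def after_training(self, model):
        # called once when training finishes (or stops early); must return the model
        elapsed = time.perf_counter() - self.start
        print(f"Training took {elapsed:.2f} seconds")
        return model
```

Note that `before_training()` and `after_training()` receive the model and must return it, while the per-iteration hooks return a boolean indicating whether to stop.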
To use this custom callback during training, simply pass an instance of the `CustomCallback` class to the `callbacks` parameter of `xgboost.train()`:
```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
import xgboost


class CustomCallback(xgboost.callback.TrainingCallback):
    def __init__(self, early_stopping_rounds, validation_data):
        super().__init__()
        self.early_stopping_rounds = early_stopping_rounds
        self.validation_data = validation_data
        self.best_score = float("inf")
        self.best_iteration = None
        self.current_round = 0

    def after_iteration(self, model, epoch, evals_log):
        current_score = evals_log["validation"]["rmse"][-1]
        self.current_round += 1
        if current_score < self.best_score:
            self.best_score = current_score
            self.best_iteration = self.current_round
        print(f"Round {self.current_round}: Best iteration = {self.best_iteration}, Best score = {self.best_score:.4f}")
        if self.best_iteration is not None and (self.current_round - self.best_iteration >= self.early_stopping_rounds):
            print(f"Early stopping at round {self.current_round}")
            model.set_attr(best_iteration=str(self.best_iteration))
            # returning True stops training
            return True
        return False


# prepare dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

# split dataset into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# define model parameters
params = {
    "objective": "reg:squarederror",
    "learning_rate": 0.1,
    "max_depth": 3,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
}

# prepare data
dtrain = xgboost.DMatrix(X_train, label=y_train)
dval = xgboost.DMatrix(X_val, label=y_val)

# create callback
custom_callback = CustomCallback(
    early_stopping_rounds=10,
    validation_data=[(dtrain, "train"), (dval, "validation")],
)

# fit the model using the training callback
model = xgboost.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(dtrain, "train"), (dval, "validation")],
    callbacks=[custom_callback],
)

# evaluate on the validation set
y_pred = model.predict(dval)
rmse = root_mean_squared_error(y_val, y_pred)
print(f"Validation RMSE: {rmse:.4f}")
```
In this example, the `CustomCallback` is instantiated with `early_stopping_rounds=10` and the validation data. The callback is then passed to `xgboost.train()` via the `callbacks` parameter. During training, the custom callback logs the progress and stops training early if the validation score doesn't improve for 10 consecutive rounds.
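One caveat: `predict()` uses all trees that were built, including the rounds after the best iteration. Since the callback stored the best iteration as a model attribute, you can restrict prediction to the trees up to that point. A sketch, assuming a recent XGBoost version where `Booster.predict()` accepts `iteration_range` (available since 1.4):

```python
# read back the attribute the callback stored (booster attributes are strings)
best_iteration = int(model.attr("best_iteration"))

# iteration_range is half-open: trees for rounds [0, best_iteration) are used;
# because the callback counts rounds from 1, this covers exactly the trees
# through the best round
y_pred_best = model.predict(dval, iteration_range=(0, best_iteration))
print(f"Best-iteration RMSE: {root_mean_squared_error(y_val, y_pred_best):.4f}")
```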
By leveraging the `TrainingCallback` class, users can easily customize the behavior of the XGBoost training process to suit their specific needs, such as implementing custom logging, early stopping, or other modifications to the training loop.
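Finally, for the specific case demonstrated here, XGBoost ships a built-in `xgboost.callback.EarlyStopping` callback that handles best-iteration tracking for you; writing a custom callback is mainly worthwhile when you need behavior the built-in doesn't cover. A minimal sketch of the built-in alternative (parameter names per the XGBoost docs; check your installed version):

```python
import xgboost

# stop if "rmse" on the "validation" set fails to improve for 10 rounds
early_stop = xgboost.callback.EarlyStopping(
    rounds=10,
    metric_name="rmse",
    data_name="validation",
    save_best=True,  # return the model at the best iteration rather than the last
)

# passed exactly like the custom callback:
# model = xgboost.train(params, dtrain, num_boost_round=1000,
#                       evals=[(dtrain, "train"), (dval, "validation")],
#                       callbacks=[early_stop])
```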